Commit bd689f40 authored by illsilin

merge from public repo

parents c160c6cf a94113a9
@@ -76,8 +76,11 @@ std::string SequenceStr(const std::vector<int>& v);
 std::string MakeTuple(const std::vector<std::string>& v);
 
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wglobal-constructors"
 template <int... xs>
 const std::string S = SequenceStr({xs...});
+#pragma clang diagnostic pop
 
 constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
 constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
...
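The push/ignore/pop pattern above keeps -Wglobal-constructors enabled for the rest of the translation unit while permitting one deliberate, dynamically initialized global. A minimal standalone sketch of the same pattern (the file and identifier names are invented for illustration):

// sketch.cpp - compile with: clang++ -std=c++17 -Wglobal-constructors -c sketch.cpp
#include <string>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
// std::string has a non-constexpr constructor, so this namespace-scope
// object would otherwise trigger -Wglobal-constructors.
const std::string kExample = "value";
#pragma clang diagnostic pop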
@@ -3,6 +3,7 @@
 #include "ck/host/device_gemm_multiple_d/operation.hpp"
 #include "ck/host/stringutils.hpp"
+#include "ck/host/types.hpp"
 #include "ck/host/utils.hpp"
 #include <cassert>
@@ -32,11 +33,11 @@ static std::string GetGemmSpec(const std::size_t m,
 }
 
 // function to update prologue/epilogue with user provided operation
-void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
+void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
 {
-    if(!prologue.empty())
+    if(!pro.empty())
     {
-        this->prologue = prologue;
+        this->prologue = pro;
         this->cde_elem_op = "CDEElementOp";
     }
     else
@@ -45,11 +46,11 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
     }
 }
 
-void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue)
+void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
 {
-    if(!epilogue.empty())
+    if(!epi.empty())
     {
-        this->epilogue = epilogue;
+        this->epilogue = epi;
         this->cde_elem_op = "CDEElementOp";
     }
     else
...
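For context on the renamed parameters: a non-empty string installs a user-provided CDE elementwise operation (and sets cde_elem_op to "CDEElementOp"), while an empty string keeps the default. A hedged usage sketch; the unqualified type name and the epilogue body are assumptions for illustration, not taken from this commit:

// Usage sketch for the setters above. Operation_Xdl_CShuffle is declared in
// "ck/host/device_gemm_multiple_d/operation.hpp"; the strings are invented.
Operation_Xdl_CShuffle op;
op.update_prologue(""); // empty: keep the default element op
op.update_epilogue(
    "struct CDEElementOp {\n"
    "    template <typename E, typename C, typename D>\n"
    "    __host__ __device__ void operator()(E& e, const C& c, const D& d) const\n"
    "    {\n"
    "        e = c + d; // illustrative combine only\n"
    "    }\n"
    "};\n"); // non-empty: cde_elem_op becomes "CDEElementOp"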
@@ -4,6 +4,7 @@
 #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
 #include <iostream>
 #include "ck/host/stringutils.hpp"
+#include "ck/host/types.hpp"
 #include "ck/host/utils.hpp"
 #include <cassert>
@@ -11,34 +12,15 @@ namespace ck {
 namespace host {
 namespace conv {
 
-// calculate appropriate Gemm Specification based on input tensor dimensions
-// NOTE: in CK, MNKPadding is always used for forward convolution
-static std::string GetGemmSpec(const std::size_t m,
-                               const std::size_t n,
-                               const std::size_t k,
-                               const std::size_t m_per_block,
-                               const std::size_t n_per_block,
-                               const std::size_t k_per_block)
-{
-    std::string spec = "";
-    if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
-        spec += "M";
-    if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
-        spec += "N";
-    if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
-        spec += "K";
-    if(spec == "")
-        return "ck::tensor_operation::device::GemmSpecialization::Default";
-    return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
-}
+// NOTE: in CK, MNKPadding is always used for forward convolution, so no
+// GemmSpec helper is added here
 
 // function to update prologue/epilogue with user provided operation
-void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue)
+void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& pro)
 {
-    if(!prologue.empty())
+    if(!pro.empty())
     {
-        this->prologue = prologue;
+        this->prologue = pro;
         this->cde_elem_op = "CDEElementOp";
     }
     else
@@ -47,11 +29,11 @@ void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue)
     }
 }
 
-void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue)
+void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epi)
 {
-    if(!epilogue.empty())
+    if(!epi.empty())
     {
-        this->epilogue = epilogue;
+        this->epilogue = epi;
         this->cde_elem_op = "CDEElementOp";
     }
     else
@@ -233,6 +215,12 @@ extern "C" __global__ void run_${name}(
                                      ${BElementwiseOperation}{},
                                      ${CDEElementwiseOperation}{1.0f, 1.0f});
+
+    if(!DeviceConv::IsSupportedArgument(arg))
+    {
+        printf("Argument is not supported.\n");
+        return;
+    }
 
     constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
 
 // GridwiseGemm
...
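For reference, the deleted GetGemmSpec helper picked a padding specialization by testing whether each GEMM dimension fills its per-block tile exactly. The same arithmetic in a standalone, runnable form (the function name is invented):

#include <cstddef>
#include <iostream>

// Same check as the removed GetGemmSpec: a dimension needs padding when
// it is not a whole multiple of its per-block tile size.
static bool needs_padding(std::size_t dim, std::size_t per_block)
{
    const std::size_t tiles = (dim + per_block - 1) / per_block; // integer_divide_ceil
    return tiles * per_block != dim;
}

int main()
{
    // m = 1000 with m_per_block = 128 rounds up to 1024, leaving a
    // 24-element remainder, so the "M" padding specialization is required.
    std::cout << std::boolalpha << needs_padding(1000, 128) << '\n'; // true
    // m = 1024 fills its tiles exactly: no padding.
    std::cout << std::boolalpha << needs_padding(1024, 128) << '\n'; // false
}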
@@ -4,7 +4,10 @@
 namespace ck {
 namespace host {
 
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wglobal-constructors"
 const std::string config_header = "";
+#pragma clang diagnostic pop
 
 std::unordered_map<std::string_view, std::string_view> GetHeaders()
 {
...
@@ -4,7 +4,9 @@ file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
 foreach(TEST_SRC ${TEST_SRCS})
     set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
     get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
-    rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC})
+    add_executable(test_host_${BASE_NAME} ${TEST_SRC})
+    add_dependencies(codegen test_host_${BASE_NAME})
+    add_test(NAME codegen_test_${BASE_NAME} COMMAND test_host_${BASE_NAME})
     target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host)
     # target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a)
     target_include_directories(test_host_${BASE_NAME} PUBLIC include())
...
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2}; ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
@@ -92,7 +92,6 @@ struct Epilogue
                                          static_cast<int>(prob.C),
                                          static_cast<int>(prob.Y),
                                          static_cast<int>(prob.X)};
-    ck::Array<ck::index_t, 5> d_lengths = {};
     ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
                                          static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
@@ -109,7 +108,6 @@ struct Epilogue
                                           1,
                                           static_cast<int>(prob.X * prob.C),
                                           static_cast<int>(prob.C)};
-    ck::Array<ck::index_t, 5> d_strides = {};
     ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
     ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
...
@@ -92,7 +92,6 @@ struct Epilogue
                                          static_cast<int>(prob.C),
                                          static_cast<int>(prob.Y),
                                          static_cast<int>(prob.X)};
-    ck::Array<ck::index_t, 5> d_lengths = {};
     ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
                                          static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
@@ -109,7 +108,6 @@ struct Epilogue
                                           1,
                                           static_cast<int>(prob.X * prob.C),
                                           static_cast<int>(prob.C)};
-    ck::Array<ck::index_t, 5> d_strides = {};
     ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
     ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
...
@@ -92,7 +92,6 @@ struct Epilogue
                                          static_cast<int>(prob.C),
                                          static_cast<int>(prob.Y),
                                          static_cast<int>(prob.X)};
-    ck::Array<ck::index_t, 5> d_lengths = {};
     ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
                                          static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
@@ -109,7 +108,6 @@ struct Epilogue
                                           1,
                                           static_cast<int>(prob.X * prob.C),
                                           static_cast<int>(prob.C)};
-    ck::Array<ck::index_t, 5> d_strides = {};
     ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
     ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
...
-rocm-docs-core==1.4.1
+rocm-docs-core==1.7.1
 sphinxcontrib-bibtex==2.6.2
@@ -4,33 +4,33 @@
 #
 #    pip-compile requirements.in
 #
-accessible-pygments==0.0.3
+accessible-pygments==0.0.5
     # via pydata-sphinx-theme
-alabaster==0.7.13
+alabaster==0.7.16
     # via sphinx
-babel==2.12.1
+babel==2.15.0
     # via
     #   pydata-sphinx-theme
     #   sphinx
-beautifulsoup4==4.11.2
+beautifulsoup4==4.12.3
     # via pydata-sphinx-theme
-breathe==4.34.0
+breathe==4.35.0
     # via rocm-docs-core
-certifi==2023.7.22
+certifi==2024.7.4
     # via requests
-cffi==1.15.1
+cffi==1.16.0
     # via
     #   cryptography
     #   pynacl
-charset-normalizer==3.1.0
+charset-normalizer==3.3.2
     # via requests
-click==8.1.3
+click==8.1.7
     # via sphinx-external-toc
-cryptography==41.0.6
+cryptography==43.0.0
     # via pyjwt
-deprecated==1.2.13
+deprecated==1.2.14
     # via pygithub
-docutils==0.16
+docutils==0.21.2
     # via
     #   breathe
     #   myst-parser
@@ -38,35 +38,35 @@ docutils==0.16
     #   pydata-sphinx-theme
     #   sphinx
     #   sphinxcontrib-bibtex
-fastjsonschema==2.18.0
+fastjsonschema==2.20.0
     # via rocm-docs-core
-gitdb==4.0.10
+gitdb==4.0.11
     # via gitpython
-gitpython==3.1.37
+gitpython==3.1.43
     # via rocm-docs-core
-idna==3.4
+idna==3.7
     # via requests
 imagesize==1.4.1
     # via sphinx
-jinja2==3.1.2
+jinja2==3.1.4
     # via
     #   myst-parser
     #   sphinx
-latexcodec==2.0.1
+latexcodec==3.0.0
     # via pybtex
-markdown-it-py==2.2.0
+markdown-it-py==3.0.0
     # via
     #   mdit-py-plugins
     #   myst-parser
-markupsafe==2.1.2
+markupsafe==2.1.5
     # via jinja2
-mdit-py-plugins==0.3.5
+mdit-py-plugins==0.4.1
     # via myst-parser
 mdurl==0.1.2
     # via markdown-it-py
-myst-parser==1.0.0
+myst-parser==3.0.1
     # via rocm-docs-core
-packaging==23.0
+packaging==24.1
     # via
     #   pydata-sphinx-theme
     #   sphinx
@@ -74,48 +74,46 @@ pybtex==0.24.0
     # via
     #   pybtex-docutils
     #   sphinxcontrib-bibtex
-pybtex-docutils==1.0.2
+pybtex-docutils==1.0.3
     # via sphinxcontrib-bibtex
-pycparser==2.21
+pycparser==2.22
     # via cffi
-pydata-sphinx-theme==0.13.3
+pydata-sphinx-theme==0.15.4
     # via
     #   rocm-docs-core
     #   sphinx-book-theme
-pygithub==1.58.1
+pygithub==2.3.0
     # via rocm-docs-core
-pygments==2.15.0
+pygments==2.18.0
     # via
     #   accessible-pygments
     #   pydata-sphinx-theme
     #   sphinx
-pyjwt[crypto]==2.6.0
+pyjwt[crypto]==2.8.0
     # via pygithub
 pynacl==1.5.0
     # via pygithub
-pyyaml==6.0
+pyyaml==6.0.1
     # via
     #   myst-parser
     #   pybtex
     #   rocm-docs-core
     #   sphinx-external-toc
-requests==2.31.0
+requests==2.32.3
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==1.4.1
+rocm-docs-core==1.7.1
     # via -r requirements.in
 six==1.16.0
-    # via
-    #   latexcodec
-    #   pybtex
-smmap==5.0.0
+    # via pybtex
+smmap==5.0.1
     # via gitdb
 snowballstemmer==2.2.0
     # via sphinx
-soupsieve==2.4
+soupsieve==2.5
     # via beautifulsoup4
-sphinx==5.3.0
+sphinx==7.4.7
     # via
     #   breathe
     #   myst-parser
@@ -127,33 +125,39 @@ sphinx==5.3.0
     #   sphinx-external-toc
     #   sphinx-notfound-page
     #   sphinxcontrib-bibtex
-sphinx-book-theme==1.0.1
+sphinx-book-theme==1.1.3
     # via rocm-docs-core
-sphinx-copybutton==0.5.1
+sphinx-copybutton==0.5.2
     # via rocm-docs-core
-sphinx-design==0.4.1
+sphinx-design==0.6.0
     # via rocm-docs-core
-sphinx-external-toc==0.3.1
+sphinx-external-toc==1.0.1
     # via rocm-docs-core
-sphinx-notfound-page==0.8.3
+sphinx-notfound-page==1.0.3
     # via rocm-docs-core
-sphinxcontrib-applehelp==1.0.4
+sphinxcontrib-applehelp==2.0.0
     # via sphinx
 sphinxcontrib-bibtex==2.6.2
     # via -r requirements.in
-sphinxcontrib-devhelp==1.0.2
+sphinxcontrib-devhelp==2.0.0
     # via sphinx
-sphinxcontrib-htmlhelp==2.0.1
+sphinxcontrib-htmlhelp==2.1.0
     # via sphinx
 sphinxcontrib-jsmath==1.0.1
     # via sphinx
-sphinxcontrib-qthelp==1.0.3
+sphinxcontrib-qthelp==2.0.0
     # via sphinx
-sphinxcontrib-serializinghtml==1.1.5
+sphinxcontrib-serializinghtml==2.0.0
     # via sphinx
-typing-extensions==4.5.0
-    # via pydata-sphinx-theme
-urllib3==1.26.18
-    # via requests
-wrapt==1.15.0
+tomli==2.0.1
+    # via sphinx
+typing-extensions==4.12.2
+    # via
+    #   pydata-sphinx-theme
+    #   pygithub
+urllib3==2.2.2
+    # via
+    #   pygithub
+    #   requests
+wrapt==1.16.0
     # via deprecated
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "common.hpp"
@@ -28,14 +28,14 @@ using DeviceGemmV2Instance =
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
         256,
-        128, 256,
+        224, 256,
         128, 16, 16,
         16, 16,
-        4, 8,
+        7, 8,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
-        2, 16, 16, 1,
+        2, 16, 16, 0,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
-        2, 16, 16, 1,
+        2, 16, 16, 0,
         1, 2, S<1, 32, 1, 8>, 8,
         ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ck::f8_t>;
 // clang-format on
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <algorithm>
 #include <cassert>
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
 inline HostTensorDescriptor
 make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
 {
-    std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
+    std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
 
     ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
...
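The switch from ck::index_t to ck::long_index_t widens the descriptor dimensions to 64 bits (in CK, index_t is a 32-bit and long_index_t a 64-bit integer). A small illustration of the overflow class this avoids, with invented sizes:

#include <cstdint>
#include <iostream>

int main()
{
    // Invented sizes for a large grouped-conv output descriptor.
    const std::int32_t G = 64, N = 256, D = 64, H = 512, W = 512;

    // The 32-bit product 64 * 256 * 64 * 512 * 512 = 2^38 overflows int32_t
    // (undefined behavior); computing in 64 bits keeps the count exact.
    const std::int64_t elements = std::int64_t{G} * N * D * H * W;

    std::cout << elements << '\n'; // 274877906944
}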
 add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
+add_example_executable(example_reduce_threadwise_multi_d reduce_threadwise_multi_d.cpp)
 add_example_executable(example_reduce_multiblock_atomic_add reduce_multiblock_atomic_add.cpp)
 add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp)
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <iostream>
 #include <initializer_list>
@@ -255,34 +255,61 @@ int main(int argc, char* argv[])
     else
     {
         // for testing half_t
+        pass =
+            pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
+                        true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass =
             pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
                         true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 
         // for testing float
+        pass =
+            pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
+                        true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
                            true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 
         // for testing double
+        pass =
+            pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
+                        true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
                            true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 
         // for testing bhalf_t
+        pass = pass &&
+               reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
               reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 
         // for testing int8_t
+        pass =
+            pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
+                        true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass =
             pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
                         true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
         // for testing int4_t using AVG operation
+        pass =
+            pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
+                        true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
                            true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 
         // for testing int4_t using MAX operation
+        pass =
+            pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
+                        true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
                            true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 #endif
...
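Each new case reduces a rank-12 tensor with all lengths 3 over dims {0, 1, 2}, exercising the ReduceShape<12, 3> instance added below. A quick check of the sizes involved:

#include <cstdint>
#include <iostream>

int main()
{
    // Rank-12 input, all lengths 3, reduced over dims {0, 1, 2}.
    std::int64_t reduce_total_length = 3 * 3 * 3; // 27 elements folded per output
    std::int64_t invariant_total_length = 1;
    for(int d = 3; d < 12; ++d)
        invariant_total_length *= 3;              // 3^9 = 19683 output elements
    std::cout << invariant_total_length << " outputs of " << reduce_total_length
              << " elements each\n";
}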
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
     auto invoker_ptr = reduce.MakeInvokerPointer();
 
-    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+    int log_level = 0, cold_niters = 5, nrepeat = 50;
+    if(beta != 0.0f)
+    {
+        std::cerr << "Warning: with beta != 0.0f there must be only one repeat for correct "
+                     "results, since the output memory is overwritten between runs."
+                  << std::endl;
+        cold_niters = 0;
+        nrepeat     = 1;
+    }
+    float avg_time = invoker_ptr->Run(
+        argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat});
 
     std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
                             invariant_total_length * sizeof(InOutDataType);
...
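Why the single-repeat guard matters: with beta != 0 the kernel computes y = alpha * reduce(x) + beta * y, reading the output it is about to overwrite, so every extra timing launch feeds a stale result back in. A scalar illustration with invented numbers:

#include <iostream>

int main()
{
    // The reduction with a non-zero beta is effectively
    //   y = alpha * reduce(x) + beta * y,
    // so each extra timing launch reuses the previous output.
    float y = 10.0f; // initial contents of the output buffer
    const float alpha = 1.0f, beta = 0.5f, reduced = 4.0f;
    for(int run = 0; run < 3; ++run)
    {
        y = alpha * reduced + beta * y;
        std::cout << "launch " << run << ": y = " << y << '\n';
        // launch 0: 9 (correct), launch 1: 8.5, launch 2: 8.25 (stale)
    }
}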
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -38,7 +38,8 @@ struct ReduceShape
     static constexpr ck::index_t NumReduceDim_ = NumReduceDim;
 };
 
-using reduce_shape_instances = std::tuple<ReduceShape<3, 1>,
+using reduce_shape_instances = std::tuple<ReduceShape<12, 3>,
+                                          ReduceShape<3, 1>,
                                           ReduceShape<3, 2>,
                                           ReduceShape<4, 1>,
                                           ReduceShape<4, 2>,
...
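The example dispatch matches an instance by (Rank, NumReduceDim), so the rank-12 tests added above need the new ReduceShape<12, 3> entry. A compile-time sketch of that pairing, with simplified stand-ins for the CK types:

#include <tuple>

// Simplified stand-in for CK's ReduceShape: each entry pairs a tensor
// rank with the number of reduced dimensions.
template <int Rank, int NumReduceDim>
struct Shape
{
    static constexpr int rank           = Rank;
    static constexpr int num_reduce_dim = NumReduceDim;
};

// An input of lengths {3, 3, ..., 3} (rank 12) reduced over dims {0, 1, 2}
// needs the (12, 3) entry; without it no instance matches and the rank-12
// tests cannot run.
using shapes = std::tuple<Shape<12, 3>, Shape<3, 1>, Shape<3, 2>, Shape<4, 1>>;

static_assert(std::tuple_element_t<0, shapes>::rank == 12, "rank-12 instance present");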
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/utility/reduction_enums.hpp"

#include "reduce_threadwise_multi_d_impl.hpp"
#include "reduce_example_common.hpp"

using namespace ck;
using namespace ck::tensor_operation::device;

static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
                                       {"verify", required_argument, nullptr, 'v'},
                                       {"help", no_argument, nullptr, '?'},
                                       {nullptr, 0, nullptr, 0}};

class SimpleAppArgs
{
    private:
    int option_index = 0;

    public:
    std::vector<size_t> inLengths = {16, 64, 32, 16};
    std::vector<int> reduceDims   = {0};
    std::vector<float> scales     = {1.0f, 0.0f};

    bool do_verification = true;
    int data_type        = 1;
    int init_method      = 2;
    bool time_kernel     = true;

    void show_usage(const char* cmd)
    {
        std::cout << "Usage of " << cmd << std::endl;
        std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
                  << std::endl;
        std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions"
                  << std::endl;
        std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
                     "comparing with the host-based reduction"
                  << std::endl;
        std::cout << "Arg1 -- data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)"
                  << std::endl;
        std::cout << "Arg2 -- init method (0=no init, 1=single integer value, 2=scoped integer "
                     "value, 3=decimal value)"
                  << std::endl;
        std::cout << "Arg3 -- time kernel (0=no, 1=yes)" << std::endl;
    }

    int processArgs(int argc, char* argv[])
    {
        using ck::host_common::getTypeValuesFromString;

        int ch;

        while(1)
        {
            ch = getopt_long(argc, argv, "D:R:v:l:", long_options, &option_index);
            if(ch == -1)
                break;

            switch(ch)
            {
            case 'D':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");
                inLengths = getTypeValuesFromString<size_t>(optarg);
                break;
            case 'R':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");
                reduceDims = getTypeValuesFromString<int>(optarg);
                break;
            case 'v':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");
                do_verification = static_cast<bool>(std::atoi(optarg));
                break;
            case '?':
                if(std::string(long_options[option_index].name) == "help")
                {
                    show_usage(argv[0]);
                    return (-1);
                }
                break;
            default: show_usage(argv[0]); return (-1);
            }
        }

        if(optind + 3 > argc)
        {
            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");
        }

        data_type   = std::atoi(argv[optind++]);
        init_method = std::atoi(argv[optind++]);
        time_kernel = static_cast<bool>(std::atoi(argv[optind]));

        if(scales.empty())
        {
            scales.push_back(1.0f);
            scales.push_back(0.0f);
        }

        return (0);
    }
};

template <typename InOutDataType,
          typename AccDataType,
          ReduceTensorOp ReduceOpId,
          index_t PropagateNan,
          index_t OutputIndex>
bool reduce_threadwise_multi_d_test(bool do_verification,
                                    int init_method,
                                    bool time_kernel,
                                    const std::vector<size_t>& inLengths,
                                    const std::vector<int>& reduceDims,
                                    float alpha,
                                    float beta)
{
    bool matched = false;
    int result   = 0;

    const auto tuple_object = reduce_shape_instances{};

    static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
        if(matched)
            return;

        using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;

        if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
            return;

        std::array<int, ShapeType::NumReduceDim_> arrReduceDims;

        ck::ranges::copy(reduceDims, arrReduceDims.begin());

        result = reduce_threadwise_multi_d_impl<InOutDataType,
                                                AccDataType,
                                                ReduceOpId,
                                                ShapeType::Rank_,
                                                ShapeType::NumReduceDim_,
                                                PropagateNan,
                                                OutputIndex>(
            do_verification, init_method, time_kernel, inLengths, arrReduceDims, alpha, beta);

        matched = true;
    });

    return result == 0;
}

constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AVG;
constexpr bool PropagateNan         = true;
constexpr bool OutputIndex          = false;

int main(int argc, char* argv[])
{
    bool pass = true;

    if(argc > 1)
    {
        SimpleAppArgs arg;

        if(arg.processArgs(argc, argv) < 0)
            return (-1);

        if(arg.data_type == 0)
        {
            pass = reduce_threadwise_multi_d_test<ck::half_t,
                                                  float,
                                                  ReduceOpId,
                                                  PropagateNan,
                                                  OutputIndex>(arg.do_verification,
                                                               arg.init_method,
                                                               arg.time_kernel,
                                                               arg.inLengths,
                                                               arg.reduceDims,
                                                               arg.scales[0],
                                                               arg.scales[1]);
        }
        else if(arg.data_type == 1)
        {
            pass =
                reduce_threadwise_multi_d_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
                    arg.do_verification,
                    arg.init_method,
                    arg.time_kernel,
                    arg.inLengths,
                    arg.reduceDims,
                    arg.scales[0],
                    arg.scales[1]);
        }
    }
    else
    {
        // for testing half_t
        pass = pass && reduce_threadwise_multi_d_test<ck::half_t,
                                                      float,
                                                      ReduceOpId,
                                                      PropagateNan,
                                                      OutputIndex>(
                           true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);

        // for testing float
        pass = pass &&
               reduce_threadwise_multi_d_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
                   true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);

        // for testing bhalf_t
        pass = pass && reduce_threadwise_multi_d_test<ck::bhalf_t,
                                                      float,
                                                      ReduceOpId,
                                                      PropagateNan,
                                                      OutputIndex>(
                           true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);
    }

    return pass ? 0 : 1;
}
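The -D and -R options above are parsed with ck::host_common::getTypeValuesFromString; a standalone equivalent of that comma-separated parsing, written as a hypothetical helper rather than the CK implementation:

#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ck::host_common::getTypeValuesFromString:
// split "16,64,32,16" into {16, 64, 32, 16}.
template <typename T>
std::vector<T> parse_csv(const std::string& s)
{
    std::vector<T> values;
    std::stringstream ss(s);
    std::string token;
    while(std::getline(ss, token, ','))
    {
        std::stringstream conv(token);
        T v{};
        conv >> v;
        values.push_back(v);
    }
    return values;
}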