Unverified Commit 3d61f89a authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #134 from ROCm/merge_from_public

Merge from public
parents c160c6cf 4558a3f8
......@@ -76,8 +76,11 @@ std::string SequenceStr(const std::vector<int>& v);
std::string MakeTuple(const std::vector<std::string>& v);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
template <int... xs>
const std::string S = SequenceStr({xs...});
#pragma clang diagnostic pop
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
......
......@@ -3,6 +3,7 @@
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/types.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
......@@ -32,11 +33,11 @@ static std::string GetGemmSpec(const std::size_t m,
}
// function to update prologue/epilogue with user provided operation
void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
{
if(!prologue.empty())
if(!pro.empty())
{
this->prologue = prologue;
this->prologue = pro;
this->cde_elem_op = "CDEElementOp";
}
else
......@@ -45,11 +46,11 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
}
}
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue)
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
{
if(!epilogue.empty())
if(!epi.empty())
{
this->epilogue = epilogue;
this->epilogue = epi;
this->cde_elem_op = "CDEElementOp";
}
else
......
......@@ -4,6 +4,7 @@
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include <iostream>
#include "ck/host/stringutils.hpp"
#include "ck/host/types.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
......@@ -11,34 +12,15 @@ namespace ck {
namespace host {
namespace conv {
// calculate appropriate Gemm Specification based on input tensor dimensions
// NOTE: in CK, MNKPadding is always used for forward convolution
static std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
// NOTE: in CK, MNKPadding is always used for forward convolution, so didn't
// add GemmSpec function here
// function to update prologue/epilogue with user provided operation
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue)
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& pro)
{
if(!prologue.empty())
if(!pro.empty())
{
this->prologue = prologue;
this->prologue = pro;
this->cde_elem_op = "CDEElementOp";
}
else
......@@ -47,11 +29,11 @@ void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologu
}
}
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue)
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epi)
{
if(!epilogue.empty())
if(!epi.empty())
{
this->epilogue = epilogue;
this->epilogue = epi;
this->cde_elem_op = "CDEElementOp";
}
else
......@@ -233,6 +215,12 @@ extern "C" __global__ void run_${name}(
${BElementwiseOperation}{},
${CDEElementwiseOperation}{1.0f, 1.0f});
if(!DeviceConv::IsSupportedArgument(arg))
{
printf("Arguement is not supported.\n");
return;
};
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
// GridwiseGemm
......
......@@ -4,7 +4,10 @@
namespace ck {
namespace host {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
const std::string config_header = "";
#pragma clang diagnostic pop
std::unordered_map<std::string_view, std::string_view> GetHeaders()
{
......
......@@ -4,7 +4,9 @@ file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
foreach(TEST_SRC ${TEST_SRCS})
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC})
add_executable(test_host_${BASE_NAME} ${TEST_SRC})
add_dependencies(codegen test_host_${BASE_NAME})
add_test(NAME codegen_test_${BASE_NAME} COMMAND test_host_${BASE_NAME})
target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host)
# target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a)
target_include_directories(test_host_${BASE_NAME} PUBLIC include())
......
......@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
......@@ -109,7 +108,6 @@ struct Epilogue
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
......@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
......@@ -109,7 +108,6 @@ struct Epilogue
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
......@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
......@@ -109,7 +108,6 @@ struct Epilogue
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
......@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C),
static_cast<int>(prob.Y),
static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
......@@ -109,7 +108,6 @@ struct Epilogue
1,
static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
......@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream,
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
}
} // namespace rtc
\ No newline at end of file
} // namespace rtc
......@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const
tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); }
} // namespace rtc
\ No newline at end of file
} // namespace rtc
rocm-docs-core==1.4.1
rocm-docs-core==1.7.1
sphinxcontrib-bibtex==2.6.2
......@@ -4,33 +4,33 @@
#
# pip-compile requirements.in
#
accessible-pygments==0.0.3
accessible-pygments==0.0.5
# via pydata-sphinx-theme
alabaster==0.7.13
alabaster==0.7.16
# via sphinx
babel==2.12.1
babel==2.15.0
# via
# pydata-sphinx-theme
# sphinx
beautifulsoup4==4.11.2
beautifulsoup4==4.12.3
# via pydata-sphinx-theme
breathe==4.34.0
breathe==4.35.0
# via rocm-docs-core
certifi==2023.7.22
certifi==2024.7.4
# via requests
cffi==1.15.1
cffi==1.16.0
# via
# cryptography
# pynacl
charset-normalizer==3.1.0
charset-normalizer==3.3.2
# via requests
click==8.1.3
click==8.1.7
# via sphinx-external-toc
cryptography==41.0.6
cryptography==43.0.0
# via pyjwt
deprecated==1.2.13
deprecated==1.2.14
# via pygithub
docutils==0.16
docutils==0.21.2
# via
# breathe
# myst-parser
......@@ -38,35 +38,35 @@ docutils==0.16
# pydata-sphinx-theme
# sphinx
# sphinxcontrib-bibtex
fastjsonschema==2.18.0
fastjsonschema==2.20.0
# via rocm-docs-core
gitdb==4.0.10
gitdb==4.0.11
# via gitpython
gitpython==3.1.37
gitpython==3.1.43
# via rocm-docs-core
idna==3.4
idna==3.7
# via requests
imagesize==1.4.1
# via sphinx
jinja2==3.1.2
jinja2==3.1.4
# via
# myst-parser
# sphinx
latexcodec==2.0.1
latexcodec==3.0.0
# via pybtex
markdown-it-py==2.2.0
markdown-it-py==3.0.0
# via
# mdit-py-plugins
# myst-parser
markupsafe==2.1.2
markupsafe==2.1.5
# via jinja2
mdit-py-plugins==0.3.5
mdit-py-plugins==0.4.1
# via myst-parser
mdurl==0.1.2
# via markdown-it-py
myst-parser==1.0.0
myst-parser==3.0.1
# via rocm-docs-core
packaging==23.0
packaging==24.1
# via
# pydata-sphinx-theme
# sphinx
......@@ -74,48 +74,46 @@ pybtex==0.24.0
# via
# pybtex-docutils
# sphinxcontrib-bibtex
pybtex-docutils==1.0.2
pybtex-docutils==1.0.3
# via sphinxcontrib-bibtex
pycparser==2.21
pycparser==2.22
# via cffi
pydata-sphinx-theme==0.13.3
pydata-sphinx-theme==0.15.4
# via
# rocm-docs-core
# sphinx-book-theme
pygithub==1.58.1
pygithub==2.3.0
# via rocm-docs-core
pygments==2.15.0
pygments==2.18.0
# via
# accessible-pygments
# pydata-sphinx-theme
# sphinx
pyjwt[crypto]==2.6.0
pyjwt[crypto]==2.8.0
# via pygithub
pynacl==1.5.0
# via pygithub
pyyaml==6.0
pyyaml==6.0.1
# via
# myst-parser
# pybtex
# rocm-docs-core
# sphinx-external-toc
requests==2.31.0
requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.4.1
rocm-docs-core==1.7.1
# via -r requirements.in
six==1.16.0
# via
# latexcodec
# pybtex
smmap==5.0.0
# via pybtex
smmap==5.0.1
# via gitdb
snowballstemmer==2.2.0
# via sphinx
soupsieve==2.4
soupsieve==2.5
# via beautifulsoup4
sphinx==5.3.0
sphinx==7.4.7
# via
# breathe
# myst-parser
......@@ -127,33 +125,39 @@ sphinx==5.3.0
# sphinx-external-toc
# sphinx-notfound-page
# sphinxcontrib-bibtex
sphinx-book-theme==1.0.1
sphinx-book-theme==1.1.3
# via rocm-docs-core
sphinx-copybutton==0.5.1
sphinx-copybutton==0.5.2
# via rocm-docs-core
sphinx-design==0.4.1
sphinx-design==0.6.0
# via rocm-docs-core
sphinx-external-toc==0.3.1
sphinx-external-toc==1.0.1
# via rocm-docs-core
sphinx-notfound-page==0.8.3
sphinx-notfound-page==1.0.3
# via rocm-docs-core
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-applehelp==2.0.0
# via sphinx
sphinxcontrib-bibtex==2.6.2
# via -r requirements.in
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-devhelp==2.0.0
# via sphinx
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-htmlhelp==2.1.0
# via sphinx
sphinxcontrib-jsmath==1.0.1
# via sphinx
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-qthelp==2.0.0
# via sphinx
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib-serializinghtml==2.0.0
# via sphinx
typing-extensions==4.5.0
# via pydata-sphinx-theme
urllib3==1.26.18
# via requests
wrapt==1.15.0
tomli==2.0.1
# via sphinx
typing-extensions==4.12.2
# via
# pydata-sphinx-theme
# pygithub
urllib3==2.2.2
# via
# pygithub
# requests
wrapt==1.16.0
# via deprecated
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......@@ -28,14 +28,14 @@ using DeviceGemmV2Instance =
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 256,
224, 256,
128, 16, 16,
16, 16,
4, 8,
7, 8,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 1,
2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 1,
2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>;
// clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cassert>
......@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
inline HostTensorDescriptor
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
{
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
......
add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
add_example_executable(example_reduce_threadwise_multi_d reduce_threadwise_multi_d.cpp)
add_example_executable(example_reduce_multiblock_atomic_add reduce_multiblock_atomic_add.cpp)
add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <initializer_list>
......@@ -255,34 +255,61 @@ int main(int argc, char* argv[])
else
{
// for testing half_t
pass =
pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass =
pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing float
pass =
pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing double
pass =
pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing bhalf_t
pass = pass &&
reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass &&
reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int8_t
pass =
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass =
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// for testing int4_t using AVG operation
pass =
pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int4_t using MAX operation
pass =
pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
auto invoker_ptr = reduce.MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
int log_level = 0, cold_niters = 5, nrepeat = 50;
if(beta != 0.0f)
{
std::cerr << "Warning: With beta != 0.0f there must be only one repeat for correct results "
"since out memory is being overwritten."
<< std::endl;
cold_niters = 0;
nrepeat = 1;
}
float avg_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat});
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
invariant_total_length * sizeof(InOutDataType);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -38,7 +38,8 @@ struct ReduceShape
static constexpr ck::index_t NumReduceDim_ = NumReduceDim;
};
using reduce_shape_instances = std::tuple<ReduceShape<3, 1>,
using reduce_shape_instances = std::tuple<ReduceShape<12, 3>,
ReduceShape<3, 1>,
ReduceShape<3, 2>,
ReduceShape<4, 1>,
ReduceShape<4, 2>,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/utility/reduction_enums.hpp"
#include "reduce_threadwise_multi_d_impl.hpp"
#include "reduce_example_common.hpp"
using namespace ck;
using namespace ck::tensor_operation::device;
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}};
class SimpleAppArgs
{
private:
int option_index = 0;
public:
std::vector<size_t> inLengths = {16, 64, 32, 16};
std::vector<int> reduceDims = {0};
std::vector<float> scales = {1.0f, 0.0f};
bool do_verification = true;
int data_type = 1;
int init_method = 2;
bool time_kernel = true;
public:
void show_usage(const char* cmd)
{
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
<< std::endl;
std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions"
<< std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
"comparing with the host-based reduction"
<< std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)"
<< std::endl;
std::cout << "Arg2 -- init method (0=no init, 1=single integer value, 2=scope integer "
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg3 -- time kernel (0=no, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
{
using ck::host_common::getTypeValuesFromString;
int ch;
while(1)
{
ch = getopt_long(argc, argv, "D:R:v:l:", long_options, &option_index);
if(ch == -1)
break;
switch(ch)
{
case 'D':
if(!optarg)
throw std::runtime_error("Invalid option format!");
inLengths = getTypeValuesFromString<size_t>(optarg);
break;
case 'R':
if(!optarg)
throw std::runtime_error("Invalid option format!");
reduceDims = getTypeValuesFromString<int>(optarg);
break;
case 'v':
if(!optarg)
throw std::runtime_error("Invalid option format!");
do_verification = static_cast<bool>(std::atoi(optarg));
break;
case '?':
if(std::string(long_options[option_index].name) == "help")
{
show_usage(argv[0]);
return (-1);
};
break;
default: show_usage(argv[0]); return (-1);
};
};
if(optind + 3 > argc)
{
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
};
data_type = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
if(scales.empty())
{
scales.push_back(1.0f);
scales.push_back(0.0f);
};
return (0);
};
};
template <typename InOutDataType,
typename AccDataType,
ReduceTensorOp ReduceOpId,
index_t PropagateNan,
index_t OutputIndex>
bool reduce_threadwise_multi_d_test(bool do_verification,
int init_method,
bool time_kernel,
const std::vector<size_t>& inLengths,
const std::vector<int>& reduceDims,
float alpha,
float beta)
{
bool matched = false;
int result = 0;
const auto tuple_object = reduce_shape_instances{};
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
if(matched)
return;
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
return;
std::array<int, ShapeType::NumReduceDim_> arrReduceDims;
ck::ranges::copy(reduceDims, arrReduceDims.begin());
result = reduce_threadwise_multi_d_impl<InOutDataType,
AccDataType,
ReduceOpId,
ShapeType::Rank_,
ShapeType::NumReduceDim_,
PropagateNan,
OutputIndex>(
do_verification, init_method, time_kernel, inLengths, arrReduceDims, alpha, beta);
matched = true;
});
return (result == 0) ? true : false;
};
constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AVG;
constexpr bool PropagateNan = true;
constexpr bool OutputIndex = false;
int main(int argc, char* argv[])
{
bool pass = true;
if(argc > 1)
{
SimpleAppArgs arg;
if(arg.processArgs(argc, argv) < 0)
return (-1);
if(arg.data_type == 0)
{
pass = reduce_threadwise_multi_d_test<ck::half_t,
float,
ReduceOpId,
PropagateNan,
OutputIndex>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inLengths,
arg.reduceDims,
arg.scales[0],
arg.scales[1]);
}
else if(arg.data_type == 1)
{
pass =
reduce_threadwise_multi_d_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inLengths,
arg.reduceDims,
arg.scales[0],
arg.scales[1]);
}
}
else
{
// for testing half_t
pass = pass && reduce_threadwise_multi_d_test<ck::half_t,
float,
ReduceOpId,
PropagateNan,
OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);
// for testing float
pass = pass &&
reduce_threadwise_multi_d_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);
// for testing bhalf_t
pass = pass && reduce_threadwise_multi_d_test<ck::bhalf_t,
float,
ReduceOpId,
PropagateNan,
OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f);
}
return (pass ? 0 : 1);
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment