"vscode:/vscode.git/clone" did not exist on "39a1f853be86e916303dbea926df5b3ced16dbbf"
Commit f6ceef78 authored by ThomasNing

merge with the develop branch

parents 536c5458 25935b57
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
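For context, the `in_strides` values above (C, Hi * Wi * G * C, ...) are consistent with a packed layout in which each dimension's stride is the product of the extents of all faster-varying dimensions. A minimal standalone sketch of that rule, using plain C++ and no CK types; the NHWGC dimension names and sizes below are illustrative assumptions, not taken from this diff:

// packed_strides.cpp - derive packed strides from dimension lengths.
// Build: g++ -std=c++17 packed_strides.cpp && ./a.out
#include <array>
#include <cstdio>

int main()
{
    // Hypothetical NHWGC extents (fastest-varying dimension last).
    const std::array<long, 5> lengths = {2 /*N*/, 8 /*Hi*/, 8 /*Wi*/, 1 /*G*/, 16 /*C*/};

    // stride[i] = product of lengths[i+1 .. end]; the last dimension has stride 1.
    std::array<long, 5> strides{};
    long running = 1;
    for(int i = 4; i >= 0; --i)
    {
        strides[i] = running;
        running *= lengths[i];
    }

    for(int i = 0; i < 5; ++i)
        std::printf("dim %d: length %ld stride %ld\n", i, lengths[i], strides[i]);
    return 0;
}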
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2}; ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
...@@ -92,7 +92,6 @@ struct Epilogue ...@@ -92,7 +92,6 @@ struct Epilogue
static_cast<int>(prob.C), static_cast<int>(prob.C),
static_cast<int>(prob.Y), static_cast<int>(prob.Y),
static_cast<int>(prob.X)}; static_cast<int>(prob.X)};
ck::Array<ck::index_t, 5> d_lengths = {};
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C), ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C), static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
...@@ -109,7 +108,6 @@ struct Epilogue ...@@ -109,7 +108,6 @@ struct Epilogue
1, 1,
static_cast<int>(prob.X * prob.C), static_cast<int>(prob.X * prob.C),
static_cast<int>(prob.C)}; static_cast<int>(prob.C)};
ck::Array<ck::index_t, 5> d_strides = {};
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1}; ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
......
find_package(hip)
file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
add_library(ck_rtc ${RTC_SOURCES}) add_library(ck_rtc ${RTC_SOURCES})
target_include_directories(ck_rtc PUBLIC include) target_include_directories(ck_rtc PUBLIC include)
......
...@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream, ...@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream,
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
} }
} // namespace rtc } // namespace rtc
\ No newline at end of file
...@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const ...@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const
tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); } tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); }
} // namespace rtc } // namespace rtc
\ No newline at end of file
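The `tmp_dir` destructor above removes the scratch directory with `std::filesystem::remove_all`. As a self-contained illustration of that RAII pattern, independent of the rtc code (the class name and naming scheme below are invented for the sketch):

// scoped_tmp_dir.cpp - RAII temporary directory, removed in the destructor.
// Build: g++ -std=c++17 scoped_tmp_dir.cpp && ./a.out
#include <filesystem>
#include <iostream>
#include <random>
#include <string>

class scoped_tmp_dir
{
    std::filesystem::path path_;

public:
    scoped_tmp_dir()
    {
        // Unique-ish name under the system temp directory (illustrative only).
        std::mt19937_64 rng{std::random_device{}()};
        path_ = std::filesystem::temp_directory_path() /
                ("scoped-tmp-" + std::to_string(rng()));
        std::filesystem::create_directories(path_);
    }
    const std::filesystem::path& path() const { return path_; }
    ~scoped_tmp_dir() { std::filesystem::remove_all(path_); } // best-effort cleanup

    scoped_tmp_dir(const scoped_tmp_dir&) = delete;
    scoped_tmp_dir& operator=(const scoped_tmp_dir&) = delete;
};

int main()
{
    scoped_tmp_dir dir;
    std::cout << "working in " << dir.path() << "\n";
    // ... write or compile files here ...
}   // directory is removed when `dir` goes out of scope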
rocm-docs-core==1.6.0 rocm-docs-core==1.7.2
sphinxcontrib-bibtex==2.6.2 sphinxcontrib-bibtex==2.6.2
...@@ -4,33 +4,33 @@ ...@@ -4,33 +4,33 @@
# #
# pip-compile requirements.in # pip-compile requirements.in
# #
accessible-pygments==0.0.3 accessible-pygments==0.0.5
# via pydata-sphinx-theme # via pydata-sphinx-theme
alabaster==0.7.13 alabaster==0.7.16
# via sphinx # via sphinx
babel==2.12.1 babel==2.15.0
# via # via
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
beautifulsoup4==4.11.2 beautifulsoup4==4.12.3
# via pydata-sphinx-theme # via pydata-sphinx-theme
breathe==4.34.0 breathe==4.35.0
# via rocm-docs-core # via rocm-docs-core
certifi==2023.7.22 certifi==2024.7.4
# via requests # via requests
cffi==1.15.1 cffi==1.16.0
# via # via
# cryptography # cryptography
# pynacl # pynacl
charset-normalizer==3.1.0 charset-normalizer==3.3.2
# via requests # via requests
click==8.1.3 click==8.1.7
# via sphinx-external-toc # via sphinx-external-toc
cryptography==41.0.6 cryptography==43.0.0
# via pyjwt # via pyjwt
deprecated==1.2.13 deprecated==1.2.14
# via pygithub # via pygithub
docutils==0.16 docutils==0.21.2
# via # via
# breathe # breathe
# myst-parser # myst-parser
...@@ -38,35 +38,35 @@ docutils==0.16 ...@@ -38,35 +38,35 @@ docutils==0.16
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
# sphinxcontrib-bibtex # sphinxcontrib-bibtex
fastjsonschema==2.18.0 fastjsonschema==2.20.0
# via rocm-docs-core # via rocm-docs-core
gitdb==4.0.10 gitdb==4.0.11
# via gitpython # via gitpython
gitpython==3.1.37 gitpython==3.1.43
# via rocm-docs-core # via rocm-docs-core
idna==3.4 idna==3.7
# via requests # via requests
imagesize==1.4.1 imagesize==1.4.1
# via sphinx # via sphinx
jinja2==3.1.2 jinja2==3.1.4
# via # via
# myst-parser # myst-parser
# sphinx # sphinx
latexcodec==2.0.1 latexcodec==3.0.0
# via pybtex # via pybtex
markdown-it-py==2.2.0 markdown-it-py==3.0.0
# via # via
# mdit-py-plugins # mdit-py-plugins
# myst-parser # myst-parser
markupsafe==2.1.2 markupsafe==2.1.5
# via jinja2 # via jinja2
mdit-py-plugins==0.3.5 mdit-py-plugins==0.4.1
# via myst-parser # via myst-parser
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
myst-parser==1.0.0 myst-parser==3.0.1
# via rocm-docs-core # via rocm-docs-core
packaging==23.0 packaging==24.1
# via # via
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
...@@ -74,48 +74,46 @@ pybtex==0.24.0 ...@@ -74,48 +74,46 @@ pybtex==0.24.0
# via # via
# pybtex-docutils # pybtex-docutils
# sphinxcontrib-bibtex # sphinxcontrib-bibtex
pybtex-docutils==1.0.2 pybtex-docutils==1.0.3
# via sphinxcontrib-bibtex # via sphinxcontrib-bibtex
pycparser==2.21 pycparser==2.22
# via cffi # via cffi
pydata-sphinx-theme==0.13.3 pydata-sphinx-theme==0.15.4
# via # via
# rocm-docs-core # rocm-docs-core
# sphinx-book-theme # sphinx-book-theme
pygithub==1.58.1 pygithub==2.3.0
# via rocm-docs-core # via rocm-docs-core
pygments==2.15.0 pygments==2.18.0
# via # via
# accessible-pygments # accessible-pygments
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
pyjwt[crypto]==2.6.0 pyjwt[crypto]==2.8.0
# via pygithub # via pygithub
pynacl==1.5.0 pynacl==1.5.0
# via pygithub # via pygithub
pyyaml==6.0 pyyaml==6.0.1
# via # via
# myst-parser # myst-parser
# pybtex # pybtex
# rocm-docs-core # rocm-docs-core
# sphinx-external-toc # sphinx-external-toc
requests==2.31.0 requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.6.0 rocm-docs-core==1.7.2
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via pybtex
# latexcodec smmap==5.0.1
# pybtex
smmap==5.0.0
# via gitdb # via gitdb
snowballstemmer==2.2.0 snowballstemmer==2.2.0
# via sphinx # via sphinx
soupsieve==2.4 soupsieve==2.5
# via beautifulsoup4 # via beautifulsoup4
sphinx==5.3.0 sphinx==7.4.7
# via # via
# breathe # breathe
# myst-parser # myst-parser
...@@ -127,33 +125,39 @@ sphinx==5.3.0 ...@@ -127,33 +125,39 @@ sphinx==5.3.0
# sphinx-external-toc # sphinx-external-toc
# sphinx-notfound-page # sphinx-notfound-page
# sphinxcontrib-bibtex # sphinxcontrib-bibtex
sphinx-book-theme==1.0.1 sphinx-book-theme==1.1.3
# via rocm-docs-core # via rocm-docs-core
sphinx-copybutton==0.5.1 sphinx-copybutton==0.5.2
# via rocm-docs-core # via rocm-docs-core
sphinx-design==0.4.1 sphinx-design==0.6.0
# via rocm-docs-core # via rocm-docs-core
sphinx-external-toc==0.3.1 sphinx-external-toc==1.0.1
# via rocm-docs-core # via rocm-docs-core
sphinx-notfound-page==0.8.3 sphinx-notfound-page==1.0.3
# via rocm-docs-core # via rocm-docs-core
sphinxcontrib-applehelp==1.0.4 sphinxcontrib-applehelp==2.0.0
# via sphinx # via sphinx
sphinxcontrib-bibtex==2.6.2 sphinxcontrib-bibtex==2.6.2
# via -r requirements.in # via -r requirements.in
sphinxcontrib-devhelp==1.0.2 sphinxcontrib-devhelp==2.0.0
# via sphinx # via sphinx
sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-htmlhelp==2.1.0
# via sphinx # via sphinx
sphinxcontrib-jsmath==1.0.1 sphinxcontrib-jsmath==1.0.1
# via sphinx # via sphinx
sphinxcontrib-qthelp==1.0.3 sphinxcontrib-qthelp==2.0.0
# via sphinx # via sphinx
sphinxcontrib-serializinghtml==1.1.5 sphinxcontrib-serializinghtml==2.0.0
# via sphinx # via sphinx
typing-extensions==4.5.0 tomli==2.0.1
# via pydata-sphinx-theme # via sphinx
urllib3==1.26.18 typing-extensions==4.12.2
# via requests # via
wrapt==1.15.0 # pydata-sphinx-theme
# pygithub
urllib3==2.2.2
# via
# pygithub
# requests
wrapt==1.16.0
# via deprecated # via deprecated
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
using ADataType = ck::f8_t; using ADataType = ck::f8_t;
using BDataType = ck::f8_t; using BDataType = ck::f8_t;
using CDataType = ck::half_t; using CDataType = ck::f8_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -34,11 +34,11 @@ inline __host__ __device__ constexpr double get_rtol() ...@@ -34,11 +34,11 @@ inline __host__ __device__ constexpr double get_rtol()
} }
else if constexpr(std::is_same_v<DataType, ck::f8_t>) else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{ {
return 1e-1; // 240 and 224 are acceptable return 2e-1;
} }
else if constexpr(std::is_same_v<DataType, ck::bf8_t>) else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{ {
return 1.5e-1; // 57344 and 49152 are acceptable return 2e-1;
} }
else else
{ {
...@@ -75,11 +75,11 @@ inline __host__ __device__ constexpr double get_atol() ...@@ -75,11 +75,11 @@ inline __host__ __device__ constexpr double get_atol()
} }
else if constexpr(std::is_same_v<DataType, ck::f8_t>) else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{ {
return 16.1; // 240 and 224 are acceptable return 2e-1;
} }
else if constexpr(std::is_same_v<DataType, ck::bf8_t>) else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{ {
return 8192.1; // 57344 and 49152 are acceptable return 2e-1;
} }
else else
{ {
......
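The relaxed f8/bf8 tolerances above feed into a mixed absolute/relative comparison. Assuming the usual form |a - b| <= atol + rtol * |b|, which is how such a check_err routine is commonly implemented (not verified against this repository), a standalone sketch shows why neighbouring f8 values such as 240 and 224 now pass with rtol = 2e-1:

// tolerance_check.cpp - mixed absolute/relative comparison sketch.
// Build: g++ -std=c++17 tolerance_check.cpp && ./a.out
#include <cmath>
#include <cstdio>

bool close(double result, double reference, double rtol, double atol)
{
    return std::abs(result - reference) <= atol + rtol * std::abs(reference);
}

int main()
{
    // 240 and 224 are the large f8 magnitudes mentioned in the removed comments;
    // with rtol = 2e-1 the slack around 240 is 0.2 * 240 = 48, so |240 - 224| = 16 passes.
    std::printf("f8  240 vs 224:     %s\n", close(224.0, 240.0, 2e-1, 2e-1) ? "ok" : "fail");
    // Same idea for bf8: |57344 - 49152| = 8192 <= 0.2 * 57344 = 11468.8.
    std::printf("bf8 57344 vs 49152: %s\n", close(49152.0, 57344.0, 2e-1, 2e-1) ? "ok" : "fail");
    return 0;
}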
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
...@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc, ...@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
inline HostTensorDescriptor inline HostTensorDescriptor
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size) make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
{ {
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_}; std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions)); ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
...@@ -255,34 +255,61 @@ int main(int argc, char* argv[]) ...@@ -255,34 +255,61 @@ int main(int argc, char* argv[])
else else
{ {
// for testing half_t // for testing half_t
pass =
pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass =
pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>( pass && reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing float // for testing float
pass =
pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>( pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing double // for testing double
pass =
pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>( pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing bhalf_t // for testing bhalf_t
pass = pass &&
reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && pass = pass &&
reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>( reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int8_t // for testing int8_t
pass =
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass =
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>( pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// for testing int4_t using AVG operation // for testing int4_t using AVG operation
pass =
pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>( pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
// for testing int4_t using MAX operation // for testing int4_t using MAX operation
pass =
pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>( pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
#endif #endif
......
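The new test cases reduce a 12-dimensional {3, 3, ..., 3} tensor over dimensions {0, 1, 2}. As a quick sanity check of what that means for shapes, a small sketch (plain C++; the helper logic below is illustrative and not the test's own code) splits the lengths into reduced and invariant parts:

// reduce_shape.cpp - split tensor lengths into reduced / invariant parts.
// Build: g++ -std=c++17 reduce_shape.cpp && ./a.out
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<long> lengths(12, 3);            // {3, 3, ..., 3}, rank 12
    std::vector<int> reduce_dims = {0, 1, 2};    // dimensions to reduce

    long reduce_total = 1, invariant_total = 1;
    std::vector<long> out_lengths;
    for(int d = 0; d < static_cast<int>(lengths.size()); ++d)
    {
        if(std::find(reduce_dims.begin(), reduce_dims.end(), d) != reduce_dims.end())
            reduce_total *= lengths[d];
        else
        {
            invariant_total *= lengths[d];
            out_lengths.push_back(lengths[d]);   // kept in the output
        }
    }

    // Expect: 3^3 = 27 elements folded into each of 3^9 = 19683 outputs.
    std::printf("reduce_total = %ld, output elements = %ld, output rank = %zu\n",
                reduce_total, invariant_total, out_lengths.size());
    return 0;
}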
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
auto invoker_ptr = reduce.MakeInvokerPointer(); auto invoker_ptr = reduce.MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); int log_level = 0, cold_niters = 5, nrepeat = 50;
if(beta != 0.0f)
{
std::cerr << "Warning: With beta != 0.0f there must be only one repeat for correct results "
"since out memory is being overwritten."
<< std::endl;
cold_niters = 0;
nrepeat = 1;
}
float avg_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat});
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
invariant_total_length * sizeof(InOutDataType); invariant_total_length * sizeof(InOutDataType);
......
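The warning added above exists because the reduction writes its output in place when beta != 0.0f, roughly out = alpha * reduce(in) + beta * out under the usual BLAS-style convention assumed here, so timing loops that re-run the kernel keep folding the previous result back into the same buffer. A scalar sketch of the effect:

// beta_repeat.cpp - why repeated timed runs are wrong when beta != 0.
// Build: g++ -std=c++17 beta_repeat.cpp && ./a.out
#include <cstdio>

int main()
{
    const double reduced = 10.0;  // pretend this is reduce(in) * alpha
    const double beta    = 0.5;

    double out_once = 1.0;                  // initial contents of the output buffer
    out_once = reduced + beta * out_once;   // single run: 10.5

    double out_repeated = 1.0;
    for(int rep = 0; rep < 5; ++rep)        // 5 timed repetitions reuse the buffer
        out_repeated = reduced + beta * out_repeated;

    std::printf("one run: %.4f, five runs: %.4f\n", out_once, out_repeated);
    // The repeated value keeps drifting, which is why nrepeat is forced to 1.
    return 0;
}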
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -38,7 +38,8 @@ struct ReduceShape ...@@ -38,7 +38,8 @@ struct ReduceShape
static constexpr ck::index_t NumReduceDim_ = NumReduceDim; static constexpr ck::index_t NumReduceDim_ = NumReduceDim;
}; };
using reduce_shape_instances = std::tuple<ReduceShape<3, 1>, using reduce_shape_instances = std::tuple<ReduceShape<12, 3>,
ReduceShape<3, 1>,
ReduceShape<3, 2>, ReduceShape<3, 2>,
ReduceShape<4, 1>, ReduceShape<4, 1>,
ReduceShape<4, 2>, ReduceShape<4, 2>,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification, ...@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
// reset input to zero // reset input to zero
in_device_buf.SetZero(); in_device_buf.SetZero();
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
for(ck::index_t d = 0; d < NDimSpatial; d++)
{
input_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
filter_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
output_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
conv_filter_dilations_i32[d] =
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
}
// do GEMM // do GEMM
auto conv = DeviceConvNdBwdDataInstance{}; auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
...@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification, ...@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_, static_cast<ck::index_t>(conv_param.N_),
conv_param.K_, static_cast<ck::index_t>(conv_param.K_),
conv_param.C_, static_cast<ck::index_t>(conv_param.C_),
conv_param.input_spatial_lengths_, input_spatial_lengths_i32,
conv_param.filter_spatial_lengths_, filter_spatial_lengths_i32,
conv_param.GetOutputSpatialLengths(), output_spatial_lengths_i32,
conv_param.conv_filter_strides_, conv_filter_strides_i32,
conv_param.conv_filter_dilations_, conv_filter_dilations_i32,
conv_param.input_left_pads_, input_left_pads_i32,
conv_param.input_right_pads_, input_right_pads_i32,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
......
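The conversion loop above narrows the 64-bit values stored in the ConvParam down to ck::index_t before they reach MakeArgumentPointer. If overflow is a concern, the cast can be range-checked first; a minimal sketch follows (the narrow_to_i32 helper is hypothetical and not part of CK):

// checked_narrow.cpp - narrow 64-bit extents to 32-bit with a range check.
// Build: g++ -std=c++17 checked_narrow.cpp && ./a.out
#include <cstdint>
#include <cstdio>
#include <limits>
#include <stdexcept>

std::int32_t narrow_to_i32(std::int64_t v)
{
    if(v < std::numeric_limits<std::int32_t>::min() ||
       v > std::numeric_limits<std::int32_t>::max())
        throw std::out_of_range("value does not fit into a 32-bit index");
    return static_cast<std::int32_t>(v);
}

int main()
{
    std::int64_t spatial_length = 960;                    // fits comfortably
    std::printf("narrowed: %d\n", narrow_to_i32(spatial_length));

    try
    {
        narrow_to_i32(std::int64_t{1} << 40);             // deliberately too large
    }
    catch(const std::out_of_range& e)
    {
        std::printf("caught: %s\n", e.what());
    }
    return 0;
}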
...@@ -23,12 +23,8 @@ ...@@ -23,12 +23,8 @@
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
#ifdef CK_ENABLE_FP8 using F8 = ck::f8_t;
using F8 = ck::f8_t; using BF8 = ck::bf8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -3,6 +3,7 @@ add_subdirectory(convinvscale) ...@@ -3,6 +3,7 @@ add_subdirectory(convinvscale)
add_subdirectory(convscale) add_subdirectory(convscale)
add_subdirectory(convscale_relu) add_subdirectory(convscale_relu)
add_subdirectory(convscale_add) add_subdirectory(convscale_add)
add_subdirectory(convscale_reduce)
add_subdirectory(multi_AB) add_subdirectory(multi_AB)
add_subdirectory(unary) add_subdirectory(unary)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_activ_xdl_convscale_reduce)
add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8)
add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/type.hpp"
namespace ew = ck::tensor_operation::element_wise;
using PassThrough = ew::PassThrough;
using ConvScaleRelu = ew::UnaryCombinedOp<ew::Scale, ew::Scale, ew::Relu>;
using ConvScale = ew::UnaryCombinedOp<ew::Scale, ew::Scale, PassThrough>;
using UnaryScaleConvert = ew::Scale;
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename ConvOutDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename ConvElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<ConvOutDataType> host_conv(out_g_n_k_wos_desc);
Tensor<ConvOutDataType> device_conv(out_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
case 11: // used for debugging
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem conv_device_buf(conv_param.GetOutputByte<ConvOutDataType>());
DeviceMem out_device_buf(conv_param.GetOutputByte<OutDataType>());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// random scale values
float scale_in = float(std::rand()) / float(RAND_MAX);
float scale_wei = float(std::rand()) / float(RAND_MAX);
float scale_out = float(std::rand()) / float(RAND_MAX);
std::cout << std::endl;
std::cout << "scale_in: " << scale_in << std::endl;
std::cout << "scale_wei: " << scale_wei << std::endl;
std::cout << "scale_out: " << scale_out << std::endl;
// convolution elementwise operation
auto conv_element_op = ConvElementOp{ew::Scale{scale_in}, ew::Scale{scale_wei}, {}};
auto scale_convert = UnaryScaleConvert{scale_out}; // elementwise scale and type cast
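Taken together, the two operators above mean the example computes, per output element, something like tmp = activation(scale_in * scale_wei * conv_acc) followed by out = convert_to_f8(scale_out * tmp). A scalar sketch of that pipeline in plain floats; the clamp to 240 stands in for the f8 conversion and is an assumption about the target format, not code from this file:

// convscale_pipeline.cpp - scalar view of the scale -> activation -> scale + convert chain.
// Build: g++ -std=c++17 convscale_pipeline.cpp && ./a.out
#include <algorithm>
#include <cstdio>

int main()
{
    const float conv_acc  = 3.75f;   // pretend accumulator value from the convolution
    const float scale_in  = 0.8f;
    const float scale_wei = 0.5f;
    const float scale_out = 1.25f;

    // ConvScaleRelu-style epilogue: two scales followed by a ReLU.
    float scaled    = conv_acc * scale_in * scale_wei;
    float activated = std::max(scaled, 0.0f);

    // Final elementwise pass: apply scale_out and "convert" to the narrow type.
    // Clamping to +/-240 here is only a stand-in for an f8 cast (assumed max magnitude).
    float converted = std::clamp(activated * scale_out, -240.0f, 240.0f);

    std::printf("acc %.3f -> scaled %.3f -> relu %.3f -> out %.3f\n",
                conv_acc, scaled, activated, converted);
    return 0;
}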
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto conv_invoker = conv.MakeInvoker();
auto conv_argument =
conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
conv_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
conv_element_op);
if(!conv.IsSupportedArgument(conv_argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
std::string kernels = conv.GetTypeString();
float avg_time = conv_invoker.Run(conv_argument, StreamConfig{nullptr, time_kernel});
using DeviceElementwiseScale = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<ConvOutDataType>, // InDataTypeTuple
ck::Tuple<OutDataType>, // OutDataTypeTuple
UnaryScaleConvert, // UnaryScaleConvert
NDimSpatial + 3, // NumDim
256, // BlockSize
128, // M0PerBlock
128, // M1PerBlock
8, // M0PerThread
8, // M1PerThread
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
auto device_ew_scale = DeviceElementwiseScale{};
auto scale_invoker = device_ew_scale.MakeInvoker();
auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths,
{e_g_n_k_wos_strides},
{e_g_n_k_wos_strides},
{conv_device_buf.GetDeviceBuffer()},
{out_device_buf.GetDeviceBuffer()},
scale_convert);
if(!device_ew_scale.IsSupportedArgument(scale_argument))
{
throw std::runtime_error(
"wrong! DeviceElementwiseScale with the specified compilation parameters does "
"not support this problem");
}
kernels += std::string("\n\t\t ") + device_ew_scale.GetTypeString();
avg_time += scale_invoker.Run(scale_argument, StreamConfig{nullptr, time_kernel});
constexpr auto ReduceOpId = ck::ReduceTensorOp::AMAX;
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename ck::reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename ck::reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
using DeviceReduceInstance =
ck::tensor_operation::device::DeviceReduceMultiBlock<ConvOutDataType,
ConvOutDataType,
ConvOutDataType,
NDimSpatial + 3,
NDimSpatial + 3,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
ck::InMemoryDataOperationEnum::Set,
true, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
256, // BlockSize
4, // MThreadClusterSize
64, // KThreadClusterSize
1, // MThreadSliceSize
1, // KThreadSliceSize
1, // InSrcVectorDim
1, // InSrcVectorSize
1>; // OutDstVectorSize
std::vector<size_t> outLengths = {1};
Tensor<ConvOutDataType> amax_host(outLengths);
Tensor<ConvOutDataType> amax_from_device(outLengths);
auto amax_host_strides = amax_host.mDesc.GetStrides();
std::array<int, NDimSpatial + 3> reduce_dims;
std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // 0,..., NDimSpatial+3-1
std::array<ck::index_t, 1> reduce_out_lengths{1};
std::array<ck::index_t, 1> reduce_out_strides{static_cast<ck::index_t>(amax_host_strides[0])};
DeviceMem amax_device(sizeof(ConvOutDataType) * amax_host.mDesc.GetElementSpaceSize());
DeviceMem index_device;
InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op;
std::tie(in_elementwise_op, acc_elementwise_op) =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(host_conv.mDesc.GetElementSize()));
// Hack convolution output strides for reduction as kernel expects stride 1 for the last
// dimension. It only works because the reduction is done on the whole tensor and result is
// independent of the order of elements.
std::array<ck::index_t, NDimSpatial + 3> reduction_strides{};
copy(HostTensorDescriptor(e_g_n_k_wos_lengths).GetStrides(), reduction_strides);
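The comment above relies on the fact that a whole-tensor AMAX does not depend on the order in which elements are visited, so re-deriving packed strides from the lengths (instead of using the real, possibly permuted output strides) changes only the traversal order, not the result. A tiny sketch of that invariance, in plain C++ and two-dimensional for brevity:

// order_invariant_amax.cpp - max of |x| is independent of traversal order.
// Build: g++ -std=c++17 order_invariant_amax.cpp && ./a.out
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // A 2x3 buffer; we read it once row-major and once column-major.
    std::vector<float> buf = {-1.5f, 7.0f, -0.25f, 3.0f, -9.5f, 2.0f};
    const int rows = 2, cols = 3;

    float amax_row_major = 0.0f, amax_col_major = 0.0f;
    for(int r = 0; r < rows; ++r)
        for(int c = 0; c < cols; ++c)
            amax_row_major = std::max(amax_row_major, std::fabs(buf[r * cols + c]));

    for(int c = 0; c < cols; ++c)
        for(int r = 0; r < rows; ++r)
            amax_col_major = std::max(amax_col_major, std::fabs(buf[r * cols + c]));

    std::printf("row-major %.2f, column-major %.2f\n", amax_row_major, amax_col_major);
    return 0;
}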
auto device_reduce = DeviceReduceInstance{};
auto reduce_invoker = device_reduce.MakeInvokerPointer();
auto reduce_argument = device_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths,
reduction_strides,
reduce_out_lengths,
reduce_out_strides,
reduce_dims,
1.0,
0.0,
conv_device_buf.GetDeviceBuffer(),
nullptr,
amax_device.GetDeviceBuffer(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
if(!device_reduce.IsSupportedArgument(reduce_argument.get()))
{
throw std::runtime_error(
"wrong! DeviceReduceInstance with the specified compilation parameters does "
"not support this runtime parameters!");
};
kernels += std::string("\n\t\t ") + device_reduce.GetTypeString();
float reduce_time =
reduce_invoker->Run(reduce_argument.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "\nReduce time: " << reduce_time << " ms" << std::endl;
avg_time += reduce_time;
std::size_t flop = conv_param.GetFlops(); // convolution FLOPs
auto conv_out_elems = host_conv.GetElementSize(); // number of elements in conv result tensor
// 3 element-wise scale multipliers + 1 AMAX
std::size_t elementwise_ops = 3 + 1;
if constexpr(ck::is_same_v<ConvElementOp, ConvScaleRelu>)
{
elementwise_ops += 1; // +1 element-wise relu
}
flop += elementwise_ops * conv_out_elems;
// convolution + elementwise scaling (in + wei + output byte count)
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, ConvOutDataType>();
num_btype += sizeof(float) + sizeof(float); // + 2 scales
// elementwise scaling + F8 conversion
num_btype += conv_param.GetOutputByte<ConvOutDataType>() + sizeof(float) +
conv_param.GetOutputByte<OutDataType>();
// AMAX
num_btype += conv_param.GetOutputByte<ConvOutDataType>() + sizeof(float);
if(time_kernel)
{
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << std::endl;
}
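A note on units in the block above: avg_time is in milliseconds, so dividing the FLOP count by 1e9 and then by milliseconds yields TFLOP/s, and dividing bytes by 1e6 and then by milliseconds yields GB/s. A quick check with made-up numbers (the figures below are illustrative, not measurements):

// perf_units.cpp - sanity check of the TFlops / GB/s unit conversion.
// Build: g++ -std=c++17 perf_units.cpp && ./a.out
#include <cstdio>

int main()
{
    const double flop      = 2.0e12;  // hypothetical operation count
    const double num_bytes = 4.0e9;   // hypothetical bytes moved
    const double time_ms   = 5.0;     // hypothetical kernel time in milliseconds

    // flop / 1e9 gives "billions of FLOPs"; per millisecond that is TFLOP/s,
    // because 1e9 FLOP / 1e-3 s = 1e12 FLOP/s.
    const double tflops = flop / 1.0e9 / time_ms;

    // bytes / 1e6 gives megabytes; per millisecond that is GB/s.
    const double gb_per_sec = num_bytes / 1.0e6 / time_ms;

    std::printf("%.1f TFlops, %.1f GB/s\n", tflops, gb_per_sec);  // 400.0 TFlops, 800.0 GB/s
    return 0;
}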
std::cout << "\nKernels: " << kernels << std::endl;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
ConvOutDataType,
InElementOp,
WeiElementOp,
ConvElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
host_conv,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
conv_element_op);
ref_invoker.Run(ref_argument);
conv_device_buf.FromDevice(device_conv.mData.data());
out_device_buf.FromDevice(out_device.mData.data());
out_host.ForEach([&](auto&, auto idx) { scale_convert(out_host(idx), host_conv(idx)); });
std::cout << "\nComparing output to reference: " << std::endl;
auto tight_tol_check = ck::utils::check_err(out_device, out_host, "Error: ");
if(!tight_tol_check)
{
std::cout << "\n\tRecompare applying tolerances...\n";
std::cout << "\t\trtol = " << get_rtol<OutDataType>() << std::endl;
std::cout << "\t\tatol = " << get_atol<OutDataType>() << std::endl;
auto loose_tol_check = ck::utils::check_err(out_device,
out_host,
"Error: incorrect convolution results!",
get_rtol<OutDataType>(),
get_atol<OutDataType>());
if(!loose_tol_check)
{
return false;
}
}
std::cout << "Success!" << std::endl;
/// Verify AMAX
using RefReduceInstance =
ck::tensor_operation::host::ReferenceReduce<ConvOutDataType,
ConvOutDataType,
ConvOutDataType,
NDimSpatial + 3,
NDimSpatial + 3,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
true,
false>;
auto ref_reduce = RefReduceInstance{};
auto ref_reduce_invoker = ref_reduce.MakeInvokerPointer();
auto ref_reduce_argument = ref_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
reduce_out_lengths,
reduce_out_strides,
reduce_dims,
1.0,
0.0,
host_conv.mData.data(),
nullptr,
amax_host.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
if(!ref_reduce.IsSupportedArgument(ref_reduce_argument.get()))
{
throw std::runtime_error(
"wrong! RefReduceInstance with the specified compilation parameters does "
"not support this runtime parameters!");
};
ref_reduce_invoker->Run(ref_reduce_argument.get());
amax_device.FromDevice(amax_from_device.mData.data());
std::cout << "\namax: " << amax_from_device.mData[0] << std::endl;
std::cout << "amax_ref: " << amax_host.mData[0] << std::endl;
return ck::utils::check_err(amax_from_device, amax_host, "Error: incorrect AMAX results!");
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using ConvOutDataType = float; // data type of convolution result
using OutDataType = ck::f8_t; // data type of final result
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
ConvOutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeDataType,
BComputeDataType>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }