Commit 5b7c2432 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'rosenrodt/gemm-standalone-bench' into wavelet_model

parents 7e493730 5a995b14
...@@ -69,16 +69,16 @@ template <> std::string type_to_string<int32_t>() { return "int32"; } ...@@ -69,16 +69,16 @@ template <> std::string type_to_string<int32_t>() { return "int32"; }
// clang-format on // clang-format on
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank> template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
void profile_normalization_impl(int do_verification, void profile_softmax_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
std::vector<index_t> in_length, std::vector<index_t> in_length,
std::vector<index_t> in_strides, std::vector<index_t> in_strides,
std::vector<index_t> reduce_dims, std::vector<index_t> reduce_dims,
AccDataType alpha, AccDataType alpha,
AccDataType beta, AccDataType beta,
NormType norm_type) NormType norm_type)
{ {
if(Rank != in_length.size()) if(Rank != in_length.size())
{ {
......
...@@ -12,8 +12,7 @@ using ck::index_t; ...@@ -12,8 +12,7 @@ using ck::index_t;
struct LayernormArgParser struct LayernormArgParser
{ {
std::unordered_map<std::string, std::vector<int>> long_opts = { std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};
{"length", {}}, {"strideXY", {}}, {"strideGamma", {}}, {"strideBeta", {}}};
bool parse_opt(int argc, char* argv[], const std::string& key, int i) bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{ {
...@@ -52,9 +51,6 @@ void print_help_layernorm() ...@@ -52,9 +51,6 @@ void print_help_layernorm()
<< "arg4: print tensor value (0: no; 1: yes)\n" << "arg4: print tensor value (0: no; 1: yes)\n"
<< "arg5: time kernel (0=no, 1=yes)\n" << "arg5: time kernel (0=no, 1=yes)\n"
<< "--length: tensor extents (e.g, --length 1024 1024) \n" << "--length: tensor extents (e.g, --length 1024 1024) \n"
<< "--strideXY: tensor strides (e.g, --strideXY 1024 1)\n"
<< "--strideGamma: tensor strides (e.g, --strideGamma 1)\n"
<< "--strideBeta: tensor strides (e.g, --strideBeta 1)\n"
<< std::endl; << std::endl;
} }
...@@ -77,10 +73,7 @@ int profile_layernorm(int argc, char* argv[]) ...@@ -77,10 +73,7 @@ int profile_layernorm(int argc, char* argv[])
// parse the long options // parse the long options
arg_parser(argc, argv); arg_parser(argc, argv);
const std::vector<index_t> length = arg_parser.long_opts["length"]; const std::vector<index_t> length = arg_parser.long_opts["length"];
const std::vector<index_t> strideXY = arg_parser.long_opts["strideXY"];
const std::vector<index_t> strideGamma = arg_parser.long_opts["strideGamma"];
const std::vector<index_t> strideBeta = arg_parser.long_opts["strideBeta"];
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -88,25 +81,13 @@ int profile_layernorm(int argc, char* argv[]) ...@@ -88,25 +81,13 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half) if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>(do_verification, ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>(
init_method, do_verification, init_method, do_log, time_kernel, length);
do_log,
time_kernel,
length,
strideXY,
strideGamma,
strideBeta);
} }
else if(data_type == ck::DataTypeEnum::Float) else if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>(do_verification, ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>(
init_method, do_verification, init_method, do_log, time_kernel, length);
do_log,
time_kernel,
length,
strideXY,
strideGamma,
strideBeta);
} }
else else
{ {
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "profiler/include/profile_normalization_impl.hpp" #include "profiler/include/profile_softmax_impl.hpp"
using ck::index_t; using ck::index_t;
using ck::profiler::NormDataType; using ck::profiler::NormDataType;
...@@ -95,30 +95,29 @@ int profile_normalization(int argc, char* argv[]) ...@@ -95,30 +95,29 @@ int profile_normalization(int argc, char* argv[])
{ {
if(data_type == NormDataType::F16_F16) if(data_type == NormDataType::F16_F16)
{ {
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 3>( ck::profiler::profile_softmax_impl<ck::half_t, float, ck::half_t, 3>(do_verification,
do_verification, init_method,
init_method, do_log,
do_log, time_kernel,
time_kernel, length,
length, stride,
stride, reduce,
reduce, float(alpha),
float(alpha), float(beta),
float(beta), norm_type);
norm_type);
} }
else if(data_type == NormDataType::F32_F32) else if(data_type == NormDataType::F32_F32)
{ {
ck::profiler::profile_normalization_impl<float, float, float, 3>(do_verification, ck::profiler::profile_softmax_impl<float, float, float, 3>(do_verification,
init_method, init_method,
do_log, do_log,
time_kernel, time_kernel,
length, length,
stride, stride,
reduce, reduce,
float(alpha), float(alpha),
float(beta), float(beta),
norm_type); norm_type);
} }
else else
{ {
...@@ -129,30 +128,29 @@ int profile_normalization(int argc, char* argv[]) ...@@ -129,30 +128,29 @@ int profile_normalization(int argc, char* argv[])
{ {
if(data_type == NormDataType::F16_F16) if(data_type == NormDataType::F16_F16)
{ {
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 4>( ck::profiler::profile_softmax_impl<ck::half_t, float, ck::half_t, 4>(do_verification,
do_verification, init_method,
init_method, do_log,
do_log, time_kernel,
time_kernel, length,
length, stride,
stride, reduce,
reduce, float(alpha),
float(alpha), float(beta),
float(beta), norm_type);
norm_type);
} }
else if(data_type == NormDataType::F32_F32) else if(data_type == NormDataType::F32_F32)
{ {
ck::profiler::profile_normalization_impl<float, float, float, 4>(do_verification, ck::profiler::profile_softmax_impl<float, float, float, 4>(do_verification,
init_method, init_method,
do_log, do_log,
time_kernel, time_kernel,
length, length,
stride, stride,
reduce, reduce,
float(alpha), float(alpha),
float(beta), float(beta),
norm_type); norm_type);
} }
else else
{ {
......
...@@ -6,11 +6,10 @@ include(googletest) ...@@ -6,11 +6,10 @@ include(googletest)
add_custom_target(tests) add_custom_target(tests)
function(add_test_executable TEST_NAME) function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}") message("adding test ${TEST_NAME}")
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> ) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
...@@ -23,6 +22,7 @@ function(add_gtest_executable TEST_NAME) ...@@ -23,6 +22,7 @@ function(add_gtest_executable TEST_NAME)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
# suppress gtest warnings # suppress gtest warnings
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main) target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
...@@ -30,7 +30,6 @@ function(add_gtest_executable TEST_NAME) ...@@ -30,7 +30,6 @@ function(add_gtest_executable TEST_NAME)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
endfunction(add_gtest_executable TEST_NAME) endfunction(add_gtest_executable TEST_NAME)
add_subdirectory(magic_number_division) add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve) add_subdirectory(space_filling_curve)
add_subdirectory(conv_util) add_subdirectory(conv_util)
...@@ -51,5 +50,5 @@ add_subdirectory(convnd_bwd_data) ...@@ -51,5 +50,5 @@ add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd) add_subdirectory(grouped_convnd_fwd)
add_subdirectory(block_to_ctile_map) add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax) add_subdirectory(softmax)
add_subdirectory(layernorm) add_subdirectory(normalization)
add_subdirectory(data_type) add_subdirectory(data_type)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp" #include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization; using ck::tensor_operation::device::GemmSpecialization;
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp" #include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization; using ck::tensor_operation::device::GemmSpecialization;
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp" #include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization; using ck::tensor_operation::device::GemmSpecialization;
......
...@@ -5,237 +5,89 @@ ...@@ -5,237 +5,89 @@
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
#include <vector> #include <vector>
#include <tuple>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/include/profile_conv_bwd_data_impl.hpp" #include "profiler/include/profile_conv_bwd_data_impl.hpp"
template <typename Tuple>
class TestConvndBwdData : public ::testing::Test class TestConvndBwdData : public ::testing::Test
{ {
protected: protected:
using DataType = std::tuple_element_t<0, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params; std::vector<ck::utils::conv::ConvParam> conv_params;
};
// 1d template <ck::index_t NDimSpatial>
TEST_F(TestConvndBwdData, Conv1dBwdData) void Run()
{
conv_params.clear();
conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
for(auto& param : conv_params)
{ {
bool pass; for(auto& param : conv_params)
{
// fp32 bool pass;
pass = ck::profiler::profile_conv_bwd_data_impl<1, EXPECT_FALSE(conv_params.empty());
ck::tensor_layout::convolution::NWC, pass = ck::profiler::profile_conv_bwd_data_impl<
ck::tensor_layout::convolution::KXC, NDimSpatial,
ck::tensor_layout::convolution::NWK, ck::tuple_element_t<NDimSpatial - 1,
float, ck::Tuple<ck::tensor_layout::convolution::NWC,
float, ck::tensor_layout::convolution::NHWC,
float>(true, // do_verification ck::tensor_layout::convolution::NDHWC>>,
1, // init_method ck::tuple_element_t<NDimSpatial - 1,
false, // do_log ck::Tuple<ck::tensor_layout::convolution::KXC,
false, // time_kernel ck::tensor_layout::convolution::KYXC,
param); ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
EXPECT_TRUE(pass); ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
// fp16 ck::tensor_layout::convolution::NDHWK>>,
pass = ck::profiler::profile_conv_bwd_data_impl<1, DataType,
ck::tensor_layout::convolution::NWC, DataType,
ck::tensor_layout::convolution::KXC, DataType>(true, // do_verification
ck::tensor_layout::convolution::NWK, 1, // init_method integer value
ck::half_t, false, // do_log
ck::half_t, false, // time_kernel
ck::half_t>(true, // do_verification param);
1, // init_method EXPECT_TRUE(pass);
false, // do_log }
false, // time_kernel }
param); };
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_bwd_data_impl<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// int8 using KernelTypes = ::testing::Types<std::tuple<float>,
pass = ck::profiler::profile_conv_bwd_data_impl<1, std::tuple<ck::half_t>,
ck::tensor_layout::convolution::NWC, std::tuple<ck::bhalf_t>,
ck::tensor_layout::convolution::KXC, std::tuple<std::int8_t>>;
ck::tensor_layout::convolution::NWK, TYPED_TEST_SUITE(TestConvndBwdData, KernelTypes);
int8_t,
int8_t,
int8_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass); // 1d
} TYPED_TEST(TestConvndBwdData, Conv1dBwdData)
{
this->conv_params.clear();
this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>();
} }
// 2d // 2d
TEST_F(TestConvndBwdData, Conv2dBwdData) TYPED_TEST(TestConvndBwdData, Conv2dBwdData)
{ {
conv_params.clear(); this->conv_params.clear();
conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back(
conv_params.push_back({2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back(
{2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
for(auto& param : conv_params) this->conv_params.push_back(
{ {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
bool pass; this->template Run<2>();
// fp32
pass = ck::profiler::profile_conv_bwd_data_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
float,
float,
float>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// fp16
pass = ck::profiler::profile_conv_bwd_data_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
ck::half_t,
ck::half_t,
ck::half_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_bwd_data_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// int8
pass = ck::profiler::profile_conv_bwd_data_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
int8_t,
int8_t,
int8_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
}
} }
// 3d // 3d
TEST_F(TestConvndBwdData, Conv3dBwdData) TYPED_TEST(TestConvndBwdData, Conv3dBwdData)
{ {
conv_params.clear(); this->conv_params.clear();
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
for(auto& param : conv_params)
{
bool pass;
// fp32
pass = ck::profiler::profile_conv_bwd_data_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
float,
float,
float>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// fp16
pass = ck::profiler::profile_conv_bwd_data_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
ck::half_t,
ck::half_t,
ck::half_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_bwd_data_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// int8
pass = ck::profiler::profile_conv_bwd_data_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
int8_t,
int8_t,
int8_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
}
} }
...@@ -5,201 +5,86 @@ ...@@ -5,201 +5,86 @@
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
#include <vector> #include <vector>
#include <tuple>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/include/profile_conv_bwd_weight_impl.hpp" #include "profiler/include/profile_conv_bwd_weight_impl.hpp"
template <typename Tuple>
class TestConvndBwdWeight : public ::testing::Test class TestConvndBwdWeight : public ::testing::Test
{ {
protected: protected:
using DataType = std::tuple_element_t<0, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params; std::vector<ck::utils::conv::ConvParam> conv_params;
}; ck::index_t split_k{2};
// 1d
TEST_F(TestConvndBwdWeight, Conv1dBwdWeight)
{
conv_params.clear();
conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
for(auto& param : conv_params) template <ck::index_t NDimSpatial>
void Run()
{ {
bool pass; for(auto& param : conv_params)
{
// fp32 bool pass;
pass = ck::profiler::profile_conv_bwd_weight_impl<1, EXPECT_FALSE(conv_params.empty());
ck::tensor_layout::convolution::NWC, pass = ck::profiler::profile_conv_bwd_weight_impl<
ck::tensor_layout::convolution::KXC, NDimSpatial,
ck::tensor_layout::convolution::NWK, ck::tuple_element_t<NDimSpatial - 1,
float, ck::Tuple<ck::tensor_layout::convolution::NWC,
float, ck::tensor_layout::convolution::NHWC,
float>(true, // do_verification ck::tensor_layout::convolution::NDHWC>>,
1, // init_method ck::tuple_element_t<NDimSpatial - 1,
false, // do_log ck::Tuple<ck::tensor_layout::convolution::KXC,
false, // time_kernel ck::tensor_layout::convolution::KYXC,
param, ck::tensor_layout::convolution::KZYXC>>,
2); ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
EXPECT_TRUE(pass); ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
// fp16 DataType,
pass = ck::profiler::profile_conv_bwd_weight_impl<1, DataType,
ck::tensor_layout::convolution::NWC, DataType>(true, // do_verification
ck::tensor_layout::convolution::KXC, 1, // init_method integer value
ck::tensor_layout::convolution::NWK, false, // do_log
ck::half_t, false, // time_kernel
ck::half_t, param,
ck::half_t>(true, // do_verification split_k);
1, // init_method EXPECT_TRUE(pass);
false, // do_log }
false, // time_kernel }
param, };
2);
EXPECT_TRUE(pass);
// bf16 using KernelTypes =
pass = ck::profiler::profile_conv_bwd_weight_impl<1, ::testing::Types<std::tuple<float>, std::tuple<ck::half_t>, std::tuple<ck::bhalf_t>>;
ck::tensor_layout::convolution::NWC, TYPED_TEST_SUITE(TestConvndBwdWeight, KernelTypes);
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass); TYPED_TEST(TestConvndBwdWeight, Test1D)
} {
this->conv_params.clear();
this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>();
} }
// 2d TYPED_TEST(TestConvndBwdWeight, Test2D)
TEST_F(TestConvndBwdWeight, Conv2dBwdWeight)
{ {
conv_params.clear(); this->conv_params.clear();
conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back(
conv_params.push_back({2, 1, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back(
{2, 1, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
for(auto& param : conv_params) this->conv_params.push_back(
{ {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
bool pass; this->template Run<2>();
// fp32
pass = ck::profiler::profile_conv_bwd_weight_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
float,
float,
float>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass);
// fp16
pass = ck::profiler::profile_conv_bwd_weight_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
ck::half_t,
ck::half_t,
ck::half_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_bwd_weight_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass);
}
} }
// 3d TYPED_TEST(TestConvndBwdWeight, Test3D)
TEST_F(TestConvndBwdWeight, Conv3dBwdWeight)
{ {
conv_params.clear(); this->conv_params.clear();
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
conv_params.push_back( this->conv_params.push_back(
{3, 1, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
for(auto& param : conv_params)
{
bool pass;
// fp32
pass = ck::profiler::profile_conv_bwd_weight_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
float,
float,
float>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass);
// fp16
pass = ck::profiler::profile_conv_bwd_weight_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
ck::half_t,
ck::half_t,
ck::half_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_bwd_weight_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param,
2);
EXPECT_TRUE(pass);
}
} }
...@@ -5,237 +5,88 @@ ...@@ -5,237 +5,88 @@
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
#include <vector> #include <vector>
#include <tuple>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/include/profile_conv_fwd_impl.hpp" #include "profiler/include/profile_conv_fwd_impl.hpp"
template <typename Tuple>
class TestConvndFwd : public ::testing::Test class TestConvndFwd : public ::testing::Test
{ {
protected: protected:
using DataType = std::tuple_element_t<0, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params; std::vector<ck::utils::conv::ConvParam> conv_params;
};
// 1d template <ck::index_t NDimSpatial>
TEST_F(TestConvndFwd, Conv1dFwd) void Run()
{
conv_params.clear();
conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
for(auto& param : conv_params)
{ {
bool pass; for(auto& param : conv_params)
{
// fp32 bool pass;
pass = ck::profiler::profile_conv_fwd_impl<1, EXPECT_FALSE(conv_params.empty());
ck::tensor_layout::convolution::NWC, pass = ck::profiler::profile_conv_fwd_impl<
ck::tensor_layout::convolution::KXC, NDimSpatial,
ck::tensor_layout::convolution::NWK, ck::tuple_element_t<NDimSpatial - 1,
float, ck::Tuple<ck::tensor_layout::convolution::NWC,
float, ck::tensor_layout::convolution::NHWC,
float>(true, // do_verification ck::tensor_layout::convolution::NDHWC>>,
1, // init_method ck::tuple_element_t<NDimSpatial - 1,
false, // do_log ck::Tuple<ck::tensor_layout::convolution::KXC,
false, // time_kernel ck::tensor_layout::convolution::KYXC,
param); ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
EXPECT_TRUE(pass); ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
// fp16 ck::tensor_layout::convolution::NDHWK>>,
pass = ck::profiler::profile_conv_fwd_impl<1, DataType,
ck::tensor_layout::convolution::NWC, DataType,
ck::tensor_layout::convolution::KXC, DataType>(true, // do_verification
ck::tensor_layout::convolution::NWK, 1, // init_method integer value
ck::half_t, false, // do_log
ck::half_t, false, // time_kernel
ck::half_t>(true, // do_verification param);
1, // init_method EXPECT_TRUE(pass);
false, // do_log }
false, // time_kernel }
param); };
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_fwd_impl<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// int8 using KernelTypes = ::testing::Types<std::tuple<float>,
pass = ck::profiler::profile_conv_fwd_impl<1, std::tuple<ck::half_t>,
ck::tensor_layout::convolution::NWC, std::tuple<ck::bhalf_t>,
ck::tensor_layout::convolution::KXC, std::tuple<std::int8_t>>;
ck::tensor_layout::convolution::NWK, TYPED_TEST_SUITE(TestConvndFwd, KernelTypes);
int8_t,
int8_t,
int8_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass); // 1d
} TYPED_TEST(TestConvndFwd, Conv1dFwd)
{
this->conv_params.clear();
this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->template Run<1>();
} }
// 2d // 2d
TEST_F(TestConvndFwd, Conv2dFwd) TYPED_TEST(TestConvndFwd, Conv2dFwd)
{ {
conv_params.clear(); this->conv_params.clear();
conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back(
conv_params.push_back({2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
conv_params.push_back({2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back(
{2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
for(auto& param : conv_params) this->conv_params.push_back(
{ {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
bool pass; this->template Run<2>();
// fp32
pass = ck::profiler::profile_conv_fwd_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
float,
float,
float>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// fp16
pass = ck::profiler::profile_conv_fwd_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
ck::half_t,
ck::half_t,
ck::half_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_fwd_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// int8
pass = ck::profiler::profile_conv_fwd_impl<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
int8_t,
int8_t,
int8_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
}
} }
// 3d // 3d
TEST_F(TestConvndFwd, Conv3dFwd) TYPED_TEST(TestConvndFwd, Conv3dFwd)
{ {
conv_params.clear(); this->conv_params.clear();
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
conv_params.push_back( this->conv_params.push_back(
{3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
for(auto& param : conv_params)
{
bool pass;
// fp32
pass = ck::profiler::profile_conv_fwd_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
float,
float,
float>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// fp16
pass = ck::profiler::profile_conv_fwd_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
ck::half_t,
ck::half_t,
ck::half_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// bf16
pass = ck::profiler::profile_conv_fwd_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
// int8
pass = ck::profiler::profile_conv_fwd_impl<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
int8_t,
int8_t,
int8_t>(true, // do_verification
1, // init_method
false, // do_log
false, // time_kernel
param);
EXPECT_TRUE(pass);
}
} }
...@@ -13,3 +13,13 @@ target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) ...@@ -13,3 +13,13 @@ target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_int8 gemm_int8.cpp) add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE utility) target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_tn_instance.cpp
instance/gemm_f16_tt_instance.cpp
)
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = ck::bhalf_t;
{ using BDataType = ck::bhalf_t;
using ADataType = ck::bhalf_t; using CDataType = ck::bhalf_t;
using BDataType = ck::bhalf_t; using AccDataType = float;
using CDataType = ck::bhalf_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = ck::half_t;
{ using BDataType = ck::half_t;
using ADataType = ck::half_t; using CDataType = ck::half_t;
using BDataType = ck::half_t; using AccDataType = float;
using CDataType = ck::half_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = float;
{ using BDataType = float;
using ADataType = float; using CDataType = float;
using BDataType = float; using AccDataType = float;
using CDataType = float;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = double;
{ using BDataType = double;
using ADataType = double; using CDataType = double;
using BDataType = double; using AccDataType = double;
using CDataType = double;
using AccDataType = double;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = int8_t;
{ using BDataType = int8_t;
using ADataType = int8_t; using CDataType = int8_t;
using BDataType = int8_t; using AccDataType = int32_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "gemm_f16_nn_instance.hpp"
#include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
#include "gemm_f16_tt_instance.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using F16 = ck::half_t;
using ADataType = F16;
using BDataType = F16;
using AccDataType = float;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
using ck::gemm_util::GemmParams;
using ck::tensor_operation::device::BaseOperator;
using ck::tensor_operation::device::DeviceGemm;
using namespace ck::tensor_operation::device::instance;
using DeviceGemmNN =
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
using DeviceGemmNT =
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
using DeviceGemmTN =
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
using DeviceGemmTT =
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
struct LayoutConfig
{
bool ARowMajor;
bool BRowMajor;
bool CRowMajor;
};
int main(int argc, char* argv[])
{
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
// clang-format off
// 104 tiles
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// 110 tiles
{GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on
};
bool do_verification = true;
bool time_kernel = true;
if(argc == 1)
{
// use default
}
else if(argc == 3)
{
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl;
return 0;
}
bool pass = true;
for(auto& p : problems)
{
GemmParams& problem_size = std::get<0>(p);
const LayoutConfig& layout_config = std::get<1>(p);
const auto& factory = std::get<2>(p);
std::vector<std::unique_ptr<BaseOperator>> ops;
factory(ops);
// overwrite strides
problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;
if(!layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(!layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
pass &= ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
}
std::cout << (pass ? "ALL TESTS PASSED" : "SOME TESTS FAILED") << std::endl;
return pass ? 0 : 1;
}
...@@ -16,21 +16,13 @@ namespace gemm_util { ...@@ -16,21 +16,13 @@ namespace gemm_util {
struct GemmParams struct GemmParams
{ {
GemmParams() ck::index_t M = 1024;
: M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) ck::index_t N = 1024;
{ ck::index_t K = 1024;
}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t StrideA; ck::index_t StrideA = 1024;
ck::index_t StrideB; ck::index_t StrideB = 1024;
ck::index_t StrideC; ck::index_t StrideC = 1024;
float alpha;
float beta;
}; };
template <typename GemmInstance, template <typename GemmInstance,
...@@ -69,7 +61,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -69,7 +61,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
Tensor<CDataType>& C, Tensor<CDataType>& C,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
bool time_kernel)
{ {
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
...@@ -94,7 +87,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -94,7 +87,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
{ {
a_m_k_device_buf.ToDevice(A.mData.data()); a_m_k_device_buf.ToDevice(A.mData.data());
b_k_n_device_buf.ToDevice(B.mData.data()); b_k_n_device_buf.ToDevice(B.mData.data());
invoker_ptr->Run(argument_ptr.get()); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * params.M * params.N * params.K;
std::size_t num_btype = sizeof(ADataType) * params.M * params.K +
sizeof(BDataType) * params.K * params.N +
sizeof(CDataType) * params.M * params.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << std::endl;
c_m_n_device_buf.FromDevice(C.mData.data()); c_m_n_device_buf.FromDevice(C.mData.data());
return true; return true;
...@@ -109,19 +115,15 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -109,19 +115,15 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
} }
} }
template <typename DeviceGemmPtr_, template <typename AccDataType>
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct TestGemm struct TestGemm
{ {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
{ {
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
...@@ -156,25 +158,42 @@ struct TestGemm ...@@ -156,25 +158,42 @@ struct TestGemm
f_generate_tensor_value(a_m_k, ADataType{}); f_generate_tensor_value(a_m_k, ADataType{});
f_generate_tensor_value(b_k_n, BDataType{}); f_generate_tensor_value(b_k_n, BDataType{});
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
} }
auto operator()(const DeviceGemmPtr_& gemmPtr) template <template <class...> class DeviceGemmPtr_,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
auto operator()(DeviceGemmPtr_<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>* gemmPtr,
const GemmParams& params = GemmParams{},
bool do_verification = true,
bool time_kernel = false)
{ {
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl; std::cout << gemmPtr->GetTypeString() << std::endl;
// Arrange auto host_tensors =
ck::gemm_util::GemmParams params; PrepareGemmTensor<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(params);
params.M = 1024;
params.N = 1024;
params.K = 1024;
params.StrideA = 1024;
params.StrideB = 1024;
params.StrideC = 1024;
auto host_tensors = PrepareGemmTensor(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors); const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors); const Tensor<BDataType>& b = std::get<1>(host_tensors);
...@@ -193,14 +212,18 @@ struct TestGemm ...@@ -193,14 +212,18 @@ struct TestGemm
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op); if(do_verification)
{
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
}
// Act // Act
bool is_supported = ck::gemm_util::RunDeviceGEMM( bool is_supported = ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel);
if(is_supported) if(is_supported && do_verification)
{ {
// Assert // Assert
bool res = false; bool res = false;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "gemm_f16_nn_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_nn_256x256 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 2, 8, 32, 32, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
using gemm_f16_nn_256x128 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
using gemm_f16_nn_128x128 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
using gemm_f16_nn_128x64 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_nn_256x256{});
}
void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_nn_256x128{});
}
void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_nn_128x128{});
}
void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_nn_128x64{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_gemm_f16_nn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_nn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment