Commit f9c478e2 authored by ltqin

Merge branch 'develop' into bmatrix_skip_lds

Parents: 7d85d04a 91d8b7d6
add_gtest_executable(test_conv_util conv_util.cpp)
-target_link_libraries(test_conv_util PRIVATE host_tensor conv_fwd_util)
+target_link_libraries(test_conv_util PRIVATE host_tensor conv_util)

#include <iostream>
#include <string>
#include <vector>

-#include "gtest/gtest.h"
+#include <gtest/gtest.h>

#include "config.hpp"
-#include "conv_fwd_util.hpp"
+#include "conv_util.hpp"
#include "tensor_layout.hpp"
#include "check_err.hpp"
@@ -15,13 +15,13 @@ class TestConvUtil : public ::testing::Test
    public:
    void SetNDParams(std::size_t ndims)
    {
-        conv_params.num_dim_spatial        = ndims;
-        conv_params.filter_spatial_lengths = std::vector<ck::index_t>(ndims, 3);
-        conv_params.input_spatial_lengths  = std::vector<ck::index_t>(ndims, 71);
-        conv_params.conv_filter_strides    = std::vector<ck::index_t>(ndims, 2);
-        conv_params.conv_filter_dilations  = std::vector<ck::index_t>(ndims, 1);
-        conv_params.input_left_pads        = std::vector<ck::index_t>(ndims, 1);
-        conv_params.input_right_pads       = std::vector<ck::index_t>(ndims, 1);
+        conv_params.num_dim_spatial_        = ndims;
+        conv_params.filter_spatial_lengths_ = std::vector<ck::index_t>(ndims, 3);
+        conv_params.input_spatial_lengths_  = std::vector<ck::index_t>(ndims, 71);
+        conv_params.conv_filter_strides_    = std::vector<ck::index_t>(ndims, 2);
+        conv_params.conv_filter_dilations_  = std::vector<ck::index_t>(ndims, 1);
+        conv_params.input_left_pads_        = std::vector<ck::index_t>(ndims, 1);
+        conv_params.input_right_pads_       = std::vector<ck::index_t>(ndims, 1);
    }

    protected:

@@ -44,29 +44,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
        std::vector<ck::index_t>{36, 36},
        "Error: ConvParams 2D default constructor."));

-    conv_params.conv_filter_strides = std::vector<ck::index_t>{1, 1};
+    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{1, 1};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}."));

-    conv_params.conv_filter_strides = std::vector<ck::index_t>{2, 2};
-    conv_params.input_left_pads     = std::vector<ck::index_t>{2, 2};
-    conv_params.input_right_pads    = std::vector<ck::index_t>{2, 2};
+    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2};
+    conv_params.input_left_pads_     = std::vector<ck::index_t>{2, 2};
+    conv_params.input_right_pads_    = std::vector<ck::index_t>{2, 2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
        std::vector<ck::index_t>{37, 37},
        "Error: ConvParams 2D padding left/right {2,2}."));

-    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
+    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2, 2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}."));

-    conv_params.conv_filter_strides   = std::vector<ck::index_t>{3, 3};
-    conv_params.input_left_pads       = std::vector<ck::index_t>{1, 1};
-    conv_params.input_right_pads      = std::vector<ck::index_t>{1, 1};
-    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
+    conv_params.conv_filter_strides_   = std::vector<ck::index_t>{3, 3};
+    conv_params.input_left_pads_       = std::vector<ck::index_t>{1, 1};
+    conv_params.input_right_pads_      = std::vector<ck::index_t>{1, 1};
+    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2, 2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
            std::vector<ck::index_t>{23, 23},
@@ -81,29 +81,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D."));

-    conv_params.conv_filter_strides = std::vector<ck::index_t>{1};
+    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{1};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}."));

-    conv_params.conv_filter_strides = std::vector<ck::index_t>{2};
-    conv_params.input_left_pads     = std::vector<ck::index_t>{2};
-    conv_params.input_right_pads    = std::vector<ck::index_t>{2};
+    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{2};
+    conv_params.input_left_pads_     = std::vector<ck::index_t>{2};
+    conv_params.input_right_pads_    = std::vector<ck::index_t>{2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
        std::vector<ck::index_t>{37},
        "Error: ConvParams 1D padding left/right {2}."));

-    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2};
+    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}."));

-    conv_params.conv_filter_strides   = std::vector<ck::index_t>{3};
-    conv_params.input_left_pads       = std::vector<ck::index_t>{1};
-    conv_params.input_right_pads      = std::vector<ck::index_t>{1};
-    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2};
+    conv_params.conv_filter_strides_   = std::vector<ck::index_t>{3};
+    conv_params.input_left_pads_       = std::vector<ck::index_t>{1};
+    conv_params.input_right_pads_      = std::vector<ck::index_t>{1};
+    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
            std::vector<ck::index_t>{23},
@@ -118,31 +118,31 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D."));

-    conv_params.conv_filter_strides = std::vector<ck::index_t>{1, 1, 1};
+    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{1, 1, 1};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
        std::vector<ck::index_t>{71, 71, 71},
        "Error: ConvParams 3D stride {1, 1, 1}."));

-    conv_params.conv_filter_strides = std::vector<ck::index_t>{2, 2, 2};
-    conv_params.input_left_pads     = std::vector<ck::index_t>{2, 2, 2};
-    conv_params.input_right_pads    = std::vector<ck::index_t>{2, 2, 2};
+    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2, 2};
+    conv_params.input_left_pads_     = std::vector<ck::index_t>{2, 2, 2};
+    conv_params.input_right_pads_    = std::vector<ck::index_t>{2, 2, 2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
        std::vector<ck::index_t>{37, 37, 37},
        "Error: ConvParams 3D padding left/right {2, 2, 2}."));

-    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2, 2};
+    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2, 2, 2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
        std::vector<ck::index_t>{36, 36, 36},
        "Error: ConvParams 3D dilation {2, 2, 2}."));

-    conv_params.conv_filter_strides   = std::vector<ck::index_t>{3, 3, 3};
-    conv_params.input_left_pads       = std::vector<ck::index_t>{1, 1, 1};
-    conv_params.input_right_pads      = std::vector<ck::index_t>{1, 1, 1};
-    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2, 2};
+    conv_params.conv_filter_strides_   = std::vector<ck::index_t>{3, 3, 3};
+    conv_params.input_left_pads_       = std::vector<ck::index_t>{1, 1, 1};
+    conv_params.input_right_pads_      = std::vector<ck::index_t>{1, 1, 1};
+    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2, 2, 2};
    out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len,
        std::vector<ck::index_t>{23, 23, 23},
...
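Editor's note, not part of the commit: the expected lengths asserted above (36, 71, 37, 23, ...) all follow the standard convolution output-size formula, which GetOutputSpatialLengths() appears to implement. A minimal standalone C++ sketch with a hypothetical helper name, for reference only:

#include <cstdint>
#include <vector>

// out = (in + pad_left + pad_right - dilation * (filter - 1) - 1) / stride + 1, per dimension
std::vector<int64_t> expected_output_lengths(const std::vector<int64_t>& in,
                                             const std::vector<int64_t>& filter,
                                             const std::vector<int64_t>& stride,
                                             const std::vector<int64_t>& dilation,
                                             const std::vector<int64_t>& pad_l,
                                             const std::vector<int64_t>& pad_r)
{
    std::vector<int64_t> out(in.size());
    for(std::size_t d = 0; d < in.size(); ++d)
    {
        out[d] = (in[d] + pad_l[d] + pad_r[d] - dilation[d] * (filter[d] - 1) - 1) / stride[d] + 1;
    }
    // e.g. in=71, filter=3, stride=2, dilation=1, pads=1 gives 36, matching the tests above
    return out;
}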
@@ -4,4 +4,4 @@ include_directories(BEFORE
)

add_test_executable(test_convnd_bwd_data convnd_bwd_data.cpp)
-target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_fwd_util)
+target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_util)
@@ -27,20 +27,20 @@ int main()
        ck::tensor_layout::convolution::NWC,
        ck::tensor_layout::convolution::KXC,
        ck::tensor_layout::convolution::NWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
        ck::half_t,

@@ -50,20 +50,20 @@ int main()
        ck::tensor_layout::convolution::NWC,
        ck::tensor_layout::convolution::KXC,
        ck::tensor_layout::convolution::NWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
        ck::bhalf_t,

@@ -73,20 +73,20 @@ int main()
        ck::tensor_layout::convolution::NWC,
        ck::tensor_layout::convolution::KXC,
        ck::tensor_layout::convolution::NWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<1,
        int8_t,

@@ -96,20 +96,20 @@ int main()
        ck::tensor_layout::convolution::NWC,
        ck::tensor_layout::convolution::KXC,
        ck::tensor_layout::convolution::NWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);
    }

    // check 2d

@@ -128,20 +128,20 @@ int main()
        ck::tensor_layout::convolution::NHWC,
        ck::tensor_layout::convolution::KYXC,
        ck::tensor_layout::convolution::NHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
        ck::half_t,

@@ -151,20 +151,20 @@ int main()
        ck::tensor_layout::convolution::NHWC,
        ck::tensor_layout::convolution::KYXC,
        ck::tensor_layout::convolution::NHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
        ck::bhalf_t,

@@ -174,20 +174,20 @@ int main()
        ck::tensor_layout::convolution::NHWC,
        ck::tensor_layout::convolution::KYXC,
        ck::tensor_layout::convolution::NHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<2,
        int8_t,

@@ -197,20 +197,20 @@ int main()
        ck::tensor_layout::convolution::NHWC,
        ck::tensor_layout::convolution::KYXC,
        ck::tensor_layout::convolution::NHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);
    }

    // check 3d

@@ -232,20 +232,20 @@ int main()
        ck::tensor_layout::convolution::NDHWC,
        ck::tensor_layout::convolution::KZYXC,
        ck::tensor_layout::convolution::NDHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
        ck::half_t,

@@ -255,20 +255,20 @@ int main()
        ck::tensor_layout::convolution::NDHWC,
        ck::tensor_layout::convolution::KZYXC,
        ck::tensor_layout::convolution::NDHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
        ck::bhalf_t,

@@ -278,20 +278,20 @@ int main()
        ck::tensor_layout::convolution::NDHWC,
        ck::tensor_layout::convolution::KZYXC,
        ck::tensor_layout::convolution::NDHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);

    pass &= ck::profiler::profile_convnd_bwd_data_impl<3,
        int8_t,

@@ -301,20 +301,20 @@ int main()
        ck::tensor_layout::convolution::NDHWC,
        ck::tensor_layout::convolution::KZYXC,
        ck::tensor_layout::convolution::NDHWK>(
-       1, // do_verification,
-       1, // init_method,
-       0, // do_log,
-       1, // nrepeat,
-       param.N,
-       param.K,
-       param.C,
-       param.input_spatial_lengths,
-       param.filter_spatial_lengths,
+       true,  // do_verification
+       1,     // init_method
+       false, // do_log
+       false, // time_kernel
+       param.N_,
+       param.K_,
+       param.C_,
+       param.input_spatial_lengths_,
+       param.filter_spatial_lengths_,
        param.GetOutputSpatialLengths(),
-       param.conv_filter_strides,
-       param.conv_filter_dilations,
-       param.input_left_pads,
-       param.input_right_pads);
+       param.conv_filter_strides_,
+       param.conv_filter_dilations_,
+       param.input_left_pads_,
+       param.input_right_pads_);
    }

    if(pass)
...
add_custom_target(test_convnd_fwd)

add_gtest_executable(test_conv1d_fwd conv1d_fwd.cpp)
-target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_instance conv_fwd_util)
+target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_instance conv_util)
add_dependencies(test_convnd_fwd test_conv1d_fwd)

add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp)
-target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_fwd_util)
+target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_util)
add_dependencies(test_convnd_fwd test_conv2d_fwd)

add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp)
-target_link_libraries(test_conv3d_fwd PRIVATE host_tensor device_conv3d_fwd_instance conv_fwd_util)
+target_link_libraries(test_conv3d_fwd PRIVATE host_tensor device_conv3d_fwd_instance conv_util)
add_dependencies(test_convnd_fwd test_conv3d_fwd)
@@ -6,7 +6,7 @@
#include "data_type.hpp"
#include "element_wise_operation.hpp"
-#include "conv_fwd_util.hpp"
+#include "library/include/ck/library/utility/conv_util.hpp"
#include "conv_util.hpp"

namespace {

@@ -19,13 +19,13 @@ bool test_conv1d_nwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPt
    namespace ctl = ck::tensor_layout::convolution;

    ck::utils::conv::ConvParams params;
-    params.num_dim_spatial        = 1;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{3};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{71};
-    params.conv_filter_strides    = std::vector<ck::index_t>{2};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1};
-    params.input_left_pads        = std::vector<ck::index_t>{1};
-    params.input_right_pads       = std::vector<ck::index_t>{1};
+    params.num_dim_spatial_        = 1;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{71};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{2};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1};

    conv::ConvFwdOpInstance<T, T, T, ctl::NWC, ctl::KCX, ctl::NWK> conv_instance(params);

@@ -44,16 +44,16 @@ TEST(Conv1DFwdNWC, TestConv1D)
    namespace ctl = ck::tensor_layout::convolution;

    ck::utils::conv::ConvParams params;
-    params.num_dim_spatial        = 1;
-    params.N                      = 2;
-    params.K                      = 16;
-    params.C                      = 4;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{3};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{16};
-    params.conv_filter_strides    = std::vector<ck::index_t>{1};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1};
-    params.input_left_pads        = std::vector<ck::index_t>{1};
-    params.input_right_pads       = std::vector<ck::index_t>{1};
+    params.num_dim_spatial_        = 1;
+    params.N_                      = 2;
+    params.K_                      = 16;
+    params.C_                      = 4;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{16};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{1};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<1>(conv_ptrs);
...
@@ -6,7 +6,7 @@
#include "data_type.hpp"
#include "element_wise_operation.hpp"
-#include "conv_fwd_util.hpp"
+#include "ck/library/utility/conv_util.hpp"
#include "conv_util.hpp"

namespace {

@@ -18,13 +18,13 @@ bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpP
    using namespace ck::utils;

    conv::ConvParams params;
-    params.num_dim_spatial        = 2;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{71, 71};
-    params.conv_filter_strides    = std::vector<ck::index_t>{2, 2};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1};
-    params.input_left_pads        = std::vector<ck::index_t>{1, 1};
-    params.input_right_pads       = std::vector<ck::index_t>{1, 1};
+    params.num_dim_spatial_        = 2;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{71, 71};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{2, 2};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1, 1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1, 1};

    conv::ConvFwdOpInstance<T, T, T> conv_instance(params);

@@ -42,11 +42,11 @@ TEST(Conv2DFwdNHWC, TestConv2D)
    using namespace ck::utils;

    ck::utils::conv::ConvParams params;
-    params.N                     = 2;
-    params.K                     = 16;
-    params.C                     = 4;
-    params.input_spatial_lengths = std::vector<ck::index_t>{16, 16};
-    params.conv_filter_strides   = std::vector<ck::index_t>{1, 1};
+    params.N_                     = 2;
+    params.K_                     = 16;
+    params.C_                     = 4;
+    params.input_spatial_lengths_ = std::vector<ck::index_t>{16, 16};
+    params.conv_filter_strides_   = std::vector<ck::index_t>{1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs);
...
@@ -7,7 +7,7 @@
#include "data_type.hpp"
#include "element_wise_operation.hpp"
-#include "conv_fwd_util.hpp"
+#include "library/include/ck/library/utility/conv_util.hpp"
#include "conv_util.hpp"

namespace {

@@ -20,14 +20,14 @@ bool test_conv3d_ndhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOp
    namespace ctl = ck::tensor_layout::convolution;

    conv::ConvParams params;
-    params.N                      = 64;
-    params.num_dim_spatial        = 3;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3, 2};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{32, 32, 2};
-    params.conv_filter_strides    = std::vector<ck::index_t>{2, 2, 2};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1, 1};
-    params.input_left_pads        = std::vector<ck::index_t>{1, 1, 1};
-    params.input_right_pads       = std::vector<ck::index_t>{1, 1, 1};
+    params.N_                      = 64;
+    params.num_dim_spatial_        = 3;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 2};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{32, 32, 2};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{2, 2, 2};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    conv::ConvFwdOpInstance<T, T, T, ctl::NDHWC, ctl::KZYXC, ctl::NDHWK> conv_instance(params);

@@ -46,16 +46,16 @@ TEST(Conv3DFwdNDHWC, TestConv3D)
    namespace ctl = ck::tensor_layout::convolution;

    conv::ConvParams params;
-    params.num_dim_spatial        = 3;
-    params.N                      = 2;
-    params.K                      = 16;
-    params.C                      = 4;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3, 3};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{16, 16, 16};
-    params.conv_filter_strides    = std::vector<ck::index_t>{1, 1, 1};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1, 1};
-    params.input_left_pads        = std::vector<ck::index_t>{1, 1, 1};
-    params.input_right_pads       = std::vector<ck::index_t>{1, 1, 1};
+    params.num_dim_spatial_        = 3;
+    params.N_                      = 2;
+    params.K_                      = 16;
+    params.C_                      = 4;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 3};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{16, 16, 16};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

@@ -77,16 +77,16 @@ TEST(Conv3DFwdNDHWC, InputOver2GB)
    // >2GB Input
    conv::ConvParams params;
-    params.num_dim_spatial        = 3;
-    params.N                      = 2;
-    params.K                      = 16;
-    params.C                      = 32;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3, 3};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{32, 1000, 1000};
-    params.conv_filter_strides    = std::vector<ck::index_t>{1, 1, 1};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1, 1};
-    params.input_left_pads        = std::vector<ck::index_t>{1, 1, 1};
-    params.input_right_pads       = std::vector<ck::index_t>{1, 1, 1};
+    params.num_dim_spatial_        = 3;
+    params.N_                      = 2;
+    params.K_                      = 16;
+    params.C_                      = 32;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 3};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{32, 1000, 1000};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

@@ -94,16 +94,16 @@ TEST(Conv3DFwdNDHWC, InputOver2GB)
    auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
        nullptr,
        nullptr,
-       params.N,
-       params.K,
-       params.C,
-       params.input_spatial_lengths,
-       params.filter_spatial_lengths,
+       params.N_,
+       params.K_,
+       params.C_,
+       params.input_spatial_lengths_,
+       params.filter_spatial_lengths_,
        params.GetOutputSpatialLengths(),
-       params.conv_filter_strides,
-       params.conv_filter_dilations,
-       params.input_left_pads,
-       params.input_right_pads,
+       params.conv_filter_strides_,
+       params.conv_filter_dilations_,
+       params.input_left_pads_,
+       params.input_right_pads_,
        PassThrough{},
        PassThrough{},
        PassThrough{});

@@ -117,16 +117,16 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB)
    // >2GB Filters
    conv::ConvParams params;
-    params.num_dim_spatial        = 3;
-    params.N                      = 2;
-    params.K                      = 16;
-    params.C                      = 32;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{4, 1000, 1000};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{16, 16, 16};
-    params.conv_filter_strides    = std::vector<ck::index_t>{1, 1, 1};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1, 1};
-    params.input_left_pads        = std::vector<ck::index_t>{1, 1, 1};
-    params.input_right_pads       = std::vector<ck::index_t>{1, 1, 1};
+    params.num_dim_spatial_        = 3;
+    params.N_                      = 2;
+    params.K_                      = 16;
+    params.C_                      = 32;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{4, 1000, 1000};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{16, 16, 16};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
+    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
+    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

@@ -134,16 +134,16 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB)
    auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
        nullptr,
        nullptr,
-       params.N,
-       params.K,
-       params.C,
-       params.input_spatial_lengths,
-       params.filter_spatial_lengths,
+       params.N_,
+       params.K_,
+       params.C_,
+       params.input_spatial_lengths_,
+       params.filter_spatial_lengths_,
        params.GetOutputSpatialLengths(),
-       params.conv_filter_strides,
-       params.conv_filter_dilations,
-       params.input_left_pads,
-       params.input_right_pads,
+       params.conv_filter_strides_,
+       params.conv_filter_dilations_,
+       params.input_left_pads_,
+       params.input_right_pads_,
        PassThrough{},
        PassThrough{},
        PassThrough{});

@@ -157,32 +157,32 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB)
    // >2GB Output
    conv::ConvParams params;
-    params.num_dim_spatial        = 3;
-    params.N                      = 2;
-    params.K                      = 16;
-    params.C                      = 2;
-    params.filter_spatial_lengths = std::vector<ck::index_t>{1, 1, 1};
-    params.input_spatial_lengths  = std::vector<ck::index_t>{1000, 1000, 30};
-    params.conv_filter_strides    = std::vector<ck::index_t>{1, 1, 1};
-    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1, 1};
-    params.input_left_pads        = std::vector<ck::index_t>{2, 2, 2};
-    params.input_right_pads       = std::vector<ck::index_t>{2, 2, 2};
+    params.num_dim_spatial_        = 3;
+    params.N_                      = 2;
+    params.K_                      = 16;
+    params.C_                      = 2;
+    params.filter_spatial_lengths_ = std::vector<ck::index_t>{1, 1, 1};
+    params.input_spatial_lengths_  = std::vector<ck::index_t>{1000, 1000, 30};
+    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
+    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
+    params.input_left_pads_        = std::vector<ck::index_t>{2, 2, 2};
+    params.input_right_pads_       = std::vector<ck::index_t>{2, 2, 2};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

    auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
        nullptr,
        nullptr,
-       params.N,
-       params.K,
-       params.C,
-       params.input_spatial_lengths,
-       params.filter_spatial_lengths,
+       params.N_,
+       params.K_,
+       params.C_,
+       params.input_spatial_lengths_,
+       params.filter_spatial_lengths_,
        params.GetOutputSpatialLengths(),
-       params.conv_filter_strides,
-       params.conv_filter_dilations,
-       params.input_left_pads,
-       params.input_right_pads,
+       params.conv_filter_strides_,
+       params.conv_filter_dilations_,
+       params.input_left_pads_,
+       params.input_right_pads_,
        PassThrough{},
        PassThrough{},
        PassThrough{});
...
@@ -4,7 +4,6 @@
#include <tuple>

#include "config.hpp"
-#include "conv_fwd_util.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "host_tensor.hpp"
...
-add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
-target_link_libraries(test_gemm_fp32 PRIVATE host_tensor)
-target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
-
-add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
-target_link_libraries(test_gemm_fp16 PRIVATE host_tensor)
-target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)
-
-add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
-target_link_libraries(test_gemm_bf16 PRIVATE host_tensor)
-target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
-
-add_test_executable(test_gemm_int8 gemm_int8.cpp)
-target_link_libraries(test_gemm_int8 PRIVATE host_tensor)
-target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
+# GEMM XDL
+add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp)
+target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor)
+target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance)
+
+add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
+target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor)
+target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance)
+
+add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
+target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor)
+target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance)
+
+add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp)
+target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor)
+target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance)
+
+# GEMM DL
+add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp)
+target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor)
+target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance)
+
+add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp)
+target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor)
+target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance)
+
+add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp)
+target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor)
+target_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance)
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

-#include "gemm_util.hpp"
+#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
-#include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>;

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

int main()
{
-    using ADataType = int8_t;
-    using BDataType = int8_t;
-    using CDataType = int8_t;
+    using ADataType   = ck::half_t;
+    using BDataType   = ck::half_t;
+    using CDataType   = ck::half_t;
+    using AccDataType = float;

    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

    bool res = true;
    std::vector<DeviceGemmNoOpPtr> gemmPtrs;

    ck::tensor_operation::device::device_gemm_instance::
-        add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs);
+        add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
            ADataType,
            BDataType,
            CDataType,
+           AccDataType,
            ColumnMajor,
            RowMajor,
            RowMajor,
            PassThrough,
            PassThrough,
            PassThrough>{}(gemmPtr);
    }

    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
-        add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs);
+        add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
            ADataType,
            BDataType,
            CDataType,
+           AccDataType,
            ColumnMajor,
            ColumnMajor,
            RowMajor,
            PassThrough,
            PassThrough,
            PassThrough>{}(gemmPtr);
    }

    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
-        add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs);
+        add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
            ADataType,
            BDataType,
            CDataType,
+           AccDataType,
            RowMajor,
            RowMajor,
            RowMajor,
            PassThrough,
            PassThrough,
            PassThrough>{}(gemmPtr);
    }

    gemmPtrs.clear();
    ck::tensor_operation::device::device_gemm_instance::
-        add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs);
+        add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
            ADataType,
            BDataType,
            CDataType,
+           AccDataType,
            RowMajor,
            ColumnMajor,
            RowMajor,
            PassThrough,
            PassThrough,
            PassThrough>{}(gemmPtr);
    }

    std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
    return res ? 0 : 1;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
@@ -60,7 +60,7 @@ template <typename DeviceGemmPtr_,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
-void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
+bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
                   const ck::gemm_util::GemmParams& params,
                   const Tensor<ADataType>& A,
                   const Tensor<BDataType>& B,

@@ -73,9 +73,6 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());

-    a_m_k_device_buf.ToDevice(A.mData.data());
-    b_k_n_device_buf.ToDevice(B.mData.data());
-
    auto invoker_ptr = gemmPtr->MakeInvokerPointer();
    auto argument_ptr =
        gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),

@@ -91,21 +88,30 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
                                     b_element_op,
                                     c_element_op);

-    if(!gemmPtr->IsSupportedArgument(argument_ptr.get()))
+    if(gemmPtr->IsSupportedArgument(argument_ptr.get()))
    {
-        throw std::runtime_error(
-            "wrong! device_gemm with the specified compilation parameters does "
-            "not support this GEMM problem");
+        a_m_k_device_buf.ToDevice(A.mData.data());
+        b_k_n_device_buf.ToDevice(B.mData.data());
+        invoker_ptr->Run(argument_ptr.get());
+        c_m_n_device_buf.FromDevice(C.mData.data());
+        return true;
    }
-
-    invoker_ptr->Run(argument_ptr.get());
-    c_m_n_device_buf.FromDevice(C.mData.data());
+    else
+    {
+        std::cout << "device_gemm with the specified compilation parameters does "
+                     "not support this GEMM problem"
+                  << std::endl;
+        return false;
+    }
}

template <typename DeviceGemmPtr_,
          typename ADataType,
          typename BDataType,
          typename CDataType,
+          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,

@@ -181,6 +187,7 @@ struct TestGemm
        ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                  BDataType,
                                                  CDataType,
+                                                  AccDataType,
                                                  AElementwiseOperation,
                                                  BElementwiseOperation,
                                                  CElementwiseOperation>;

@@ -188,28 +195,40 @@ struct TestGemm
            a, b, c_host, a_element_op, b_element_op, c_element_op);

        // Act
-        ck::gemm_util::RunDeviceGEMM(
-            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
-
-        // Assert
-        bool res = false;
-        if(std::is_same<CDataType, float>::value)
-        {
-            res = ck::utils::check_err(c_device.mData, c_host.mData);
-            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
-        }
-        else if(std::is_same<CDataType, ck::half_t>::value)
-        {
-            res = ck::utils::check_err(c_device.mData, c_host.mData);
-            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
-        }
-        else if(std::is_same<CDataType, int8_t>::value)
-        {
-            res = ck::utils::check_err(c_device.mData, c_host.mData);
-            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
-        }
-
-        return res;
+        bool is_supported = ck::gemm_util::RunDeviceGEMM(
+            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
+
+        if(is_supported)
+        {
+            // Assert
+            bool res = false;
+            if(std::is_same<CDataType, float>::value)
+            {
+                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+            }
+            else if(std::is_same<CDataType, ck::half_t>::value)
+            {
+                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+            }
+            else if(std::is_same<CDataType, int8_t>::value)
+            {
+                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+            }
+            else if(std::is_same<CDataType, double>::value)
+            {
+                res = ck::utils::check_err(c_device.mData, c_host.mData);
+                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+            }
+            return res;
+        }
+        else
+        {
+            return true;
+        }
    }
};
...@@ -299,6 +318,7 @@ struct TestGemmBF16 ...@@ -299,6 +318,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel // use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance = using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float, ck::tensor_operation::host::ReferenceGemm<float,
float,
float, float,
float, float,
AElementwiseOperation, AElementwiseOperation,
......
@@ -15,7 +15,7 @@
 #include "host_gemm.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
...
@@ -13,7 +13,7 @@
 #include "host_gemm.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "gemm_specialization.hpp"
@@ -52,9 +52,10 @@ void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
 int main()
 {
     using ADataType = ck::half_t;
     using BDataType = ck::half_t;
     using CDataType = ck::half_t;
+    using AccDataType = float;

     using RowMajor = ck::tensor_layout::gemm::RowMajor;
     using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -74,6 +75,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        RowMajor,
                                        RowMajor,
@@ -96,6 +98,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        ColumnMajor,
                                        RowMajor,
@@ -118,6 +121,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        RowMajor,
                                        RowMajor,
@@ -142,6 +146,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        ColumnMajor,
                                        RowMajor,
...
@@ -15,7 +15,7 @@
 #include "host_gemm.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
@@ -53,9 +53,10 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<De
 int main()
 {
     using ADataType = float;
     using BDataType = float;
     using CDataType = float;
+    using AccDataType = float;

     using RowMajor = ck::tensor_layout::gemm::RowMajor;
     using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -75,6 +76,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        RowMajor,
                                        RowMajor,
@@ -97,6 +99,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        ColumnMajor,
                                        RowMajor,
@@ -119,6 +122,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        RowMajor,
                                        RowMajor,
@@ -141,6 +145,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        ColumnMajor,
                                        RowMajor,
...
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
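// Returns the GCN architecture name reported by the active HIP device, or an empty string if the query fails.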
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string name(props.gcnArchName);
return name;
}
int main()
{
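// The fp64 XDL instances use double-precision MFMA, which is only provided by gfx90a;
// on other devices the test is skipped and reported as passing.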
if(get_device_name().find("gfx90a") == std::string::npos)
{
std::cout << "TestGemm ..... SUCCESS" << std::endl;
return 0;
}
using ADataType = double;
using BDataType = double;
using CDataType = double;
using AccDataType = double;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
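// int8 GEMM accumulates in int32 to avoid overflow in the K-dimension reduction.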
using AccDataType = int32_t;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
bool res = true;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
@@ -16,22 +16,22 @@ int main()
     pass = pass &&
            ck::profiler::
                profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Row, Row>(
-                   true, 1, false, 1, M, N, K, K, N, N);
+                   true, 1, false, false, M, N, K, K, N, N);

     pass = pass &&
            ck::profiler::
                profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Col, Row>(
-                   true, 1, false, 1, M, N, K, K, K, N);
+                   true, 1, false, false, M, N, K, K, K, N);

     pass = pass &&
            ck::profiler::
                profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Row, Row>(
-                   true, 1, false, 1, M, N, K, M, N, N);
+                   true, 1, false, false, M, N, K, M, N, N);

     pass = pass &&
            ck::profiler::
                profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Col, Row>(
-                   true, 1, false, 1, M, N, K, M, K, N);
+                   true, 1, false, false, M, N, K, M, K, N);

     if(pass)
     {
...