"docs/source/vscode:/vscode.git/clone" did not exist on "37c1e3c218ed9987cb6e1a52a2efdeed2e3c304a"
Unverified Commit 9684677a authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into hip_tensor_permute

parents 36f6966a 98fd41f5
...@@ -53,7 +53,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -53,7 +53,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
template <typename ConvTensorRearrangeOp> template <typename ConvTensorRearrangeOp>
bool Run() bool Run()
{ {
const auto G = conv_param.G_;
const auto N = conv_param.N_; const auto N = conv_param.N_;
const auto C = conv_param.C_; const auto C = conv_param.C_;
const auto FakeC = const auto FakeC =
...@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
const auto image_desc = const auto image_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>( ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
conv_param); conv_param);
const auto gemm_desc = HostTensorDescriptor({NDoHoWo, CZYX}); const auto gemm_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
std::array<ck::index_t, 2> output_m_k_strides{}; std::array<ck::index_t, 3> output_g_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths); copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_param.output_spatial_lengths_, output_spatial_lengths); copy(conv_param.output_spatial_lengths_, output_spatial_lengths);
copy(image_desc.GetStrides(), input_g_n_c_wis_strides); copy(image_desc.GetStrides(), input_g_n_c_wis_strides);
copy(gemm_desc.GetStrides(), output_m_k_strides); copy(gemm_desc.GetStrides(), output_g_m_k_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_left_pads_, input_left_pads);
...@@ -100,13 +100,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -100,13 +100,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
auto img2col = DeviceImgToColInstance{}; auto img2col = DeviceImgToColInstance{};
auto argument = img2col.MakeArgument(nullptr, auto argument = img2col.MakeArgument(nullptr,
nullptr, nullptr,
G,
N, N,
IsCPacked ? C : FakeC, IsCPacked ? C : FakeC,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, input_g_n_c_wis_strides,
output_m_k_strides, output_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -119,13 +120,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -119,13 +120,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
auto col2img = DeviceColToimgInstance{}; auto col2img = DeviceColToimgInstance{};
auto argument = col2img.MakeArgument(nullptr, auto argument = col2img.MakeArgument(nullptr,
nullptr, nullptr,
G,
N, N,
IsCPacked ? C : FakeC, IsCPacked ? C : FakeC,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_g_n_c_wis_strides, input_g_n_c_wis_strides,
output_m_k_strides, output_g_m_k_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -108,6 +108,10 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) ...@@ -108,6 +108,10 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
// kloops % 2 // kloops % 2
Ks = std::vector<int>{256, 512, 320, 768}; Ks = std::vector<int>{256, 512, 320, 768};
EXPECT_FALSE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
Ks = std::vector<int>{256, 512, 384, 768};
EXPECT_TRUE( EXPECT_TRUE(
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch)); DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment