Commit 401e643e authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'develop' into feature/use-larger-tile-size-for-chunk-prefill

parents d783a8cf fdfe2102
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" #include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdData : public ::testing::Test class TestGroupedConvndBwdDataXdl : public ::testing::Test
{ {
protected: protected:
using DataType = std::tuple_element_t<0, Tuple>; using DataType = std::tuple_element_t<0, Tuple>;
...@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution; ...@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>, using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>, std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>, std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NHWGK, GKYXC, NHWGC>, std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>, std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>, std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>, using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>, std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>, std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>, std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>, std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>, std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple> class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl<Tuple>
{ {
}; };
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple> class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl<Tuple>
{ {
}; };
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
...@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) ...@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
// SplitN case
this->conv_params.push_back(
{2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
...@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) ...@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// SplitN case
this->conv_params.push_back({3,
1,
128,
4,
192,
{2, 2, 2},
{2, 224, 224},
{1, 224, 224},
{1, 1, 1},
{0, 0, 0},
{0, 0, 0}});
this->template Run<3>(); this->template Run<3>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment