"docs/vscode:/vscode.git/clone" did not exist on "cf6ce7c77514189fd021f191603611e4d47a72b9"
Unverified Commit 29dcb956 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #33 from ROCm/lwpck-1292

Merge from the public repo.
parents 29deceb6 cbcc844e
...@@ -147,3 +147,9 @@ export onnx_log="perf_onnx_gemm.log" ...@@ -147,3 +147,9 @@ export onnx_log="perf_onnx_gemm.log"
print_log_header $onnx_log $env_type $branch $host_name print_log_header $onnx_log $env_type $branch $host_name
./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log ./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log ./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log
#run mixed fp16/fp8 and fp8/fp16 gemm tests
export mixed_gemm_log="perf_mixed_gemm.log"
print_log_header $mixed_gemm_log $env_type $branch $host_name
./profile_mixed_gemm.sh gemm_splitk 4 0 $verify 2 0 1 16 2>&1 | tee -a $mixed_gemm_log
./profile_mixed_gemm.sh gemm_splitk 5 0 $verify 2 0 1 16 2>&1 | tee -a $mixed_gemm_log
\ No newline at end of file
...@@ -3,7 +3,7 @@ include_directories(BEFORE ...@@ -3,7 +3,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/profiler/include ${PROJECT_SOURCE_DIR}/profiler/include
) )
include(googletest) include(gtest)
add_custom_target(tests) add_custom_target(tests)
...@@ -50,6 +50,7 @@ function(add_test_executable TEST_NAME) ...@@ -50,6 +50,7 @@ function(add_test_executable TEST_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(ARGN) if(ARGN)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
...@@ -58,9 +59,7 @@ function(add_test_executable TEST_NAME) ...@@ -58,9 +59,7 @@ function(add_test_executable TEST_NAME)
endif() endif()
#message("add_test returns ${result}") #message("add_test returns ${result}")
set(result ${result} PARENT_SCOPE) set(result ${result} PARENT_SCOPE)
endfunction(add_test_executable TEST_NAME) endfunction()
include(GoogleTest)
function(add_gtest_executable TEST_NAME) function(add_gtest_executable TEST_NAME)
message("adding gtest ${TEST_NAME}") message("adding gtest ${TEST_NAME}")
...@@ -109,20 +108,21 @@ function(add_gtest_executable TEST_NAME) ...@@ -109,20 +108,21 @@ function(add_gtest_executable TEST_NAME)
# suppress gtest warnings # suppress gtest warnings
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main) target_link_libraries(${TEST_NAME} PRIVATE gtest_main getopt::getopt)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
set(result 0) set(result 0)
endif() endif()
#message("add_gtest returns ${result}") #message("add_gtest returns ${result}")
set(result ${result} PARENT_SCOPE) set(result ${result} PARENT_SCOPE)
endfunction(add_gtest_executable TEST_NAME) endfunction()
add_subdirectory(magic_number_division) add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve) add_subdirectory(space_filling_curve)
add_subdirectory(conv_util) add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd) add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(gemm_add)
add_subdirectory(gemm_layernorm) add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k) add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce) add_subdirectory(gemm_reduce)
...@@ -140,6 +140,8 @@ add_subdirectory(grouped_convnd_bwd_weight) ...@@ -140,6 +140,8 @@ add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map) add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax) add_subdirectory(softmax)
add_subdirectory(normalization_fwd) add_subdirectory(normalization_fwd)
add_subdirectory(normalization_bwd_data)
add_subdirectory(normalization_bwd_gamma_beta)
add_subdirectory(data_type) add_subdirectory(data_type)
add_subdirectory(elementwise_normalization) add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm) add_subdirectory(batchnorm)
...@@ -149,6 +151,8 @@ add_subdirectory(batched_gemm_multi_d) ...@@ -149,6 +151,8 @@ add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data) add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange) add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose) add_subdirectory(transpose)
add_subdirectory(permute_scale)
add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
...@@ -135,6 +135,8 @@ class TestConvTensorRearrangeInterface : public ::testing::Test ...@@ -135,6 +135,8 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
return col2img.IsSupportedArgument(argument); return col2img.IsSupportedArgument(argument);
} }
throw std::runtime_error("Conv_tensor_rearrange: problem with tensor rearrange operator. ");
return 1;
} }
}; };
......
# Tests for GEMM + elementwise-add fusions (plain add, add+relu, add+silu, add+fastgelu).
# NOTE(review): test_gemm_add is built from a .hpp only; CMake treats header
# sources as HEADER_FILE_ONLY, so this target would compile no TU of its own
# and get main() solely from gtest_main -- confirm a test_gemm_add.cpp was not
# accidentally omitted.
add_gtest_executable(test_gemm_add test_gemm_add.hpp)
target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance)
# The relu/silu/fastgelu suites include test_gemm_add.hpp, so they also need
# the base device_gemm_add_instance in addition to their own fused instance.
add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp)
target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp)
target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp)
target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

// Include guard added: this header is included by the add_relu / add_silu /
// add_fastgelu test TUs and declares entities at namespace scope.
#pragma once

#include "gtest/gtest.h"

#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_impl.hpp"

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using I8   = int8_t;
using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

// Typed gtest fixture that drives the gemm+add profiler (with verification
// enabled) over a fixed set of problem sizes. Tuple layout:
//   <AData, BData, AccData, D0Data, EData, ALayout, BLayout, D0Layout, ELayout>
// Derived fixtures (relu/silu/fastgelu) override GetImpl() to swap in their
// fused-kernel profiler while reusing Run().
template <typename Tuple>
class TestGemmAdd : public ::testing::Test
{
    protected:
    using ADataType   = std::tuple_element_t<0, Tuple>;
    using BDataType   = std::tuple_element_t<1, Tuple>;
    using AccDataType = std::tuple_element_t<2, Tuple>;
    using D0DataType  = std::tuple_element_t<3, Tuple>;
    using EDataType   = std::tuple_element_t<4, Tuple>;
    using ALayout     = std::tuple_element_t<5, Tuple>;
    using BLayout     = std::tuple_element_t<6, Tuple>;
    using D0Layout    = std::tuple_element_t<7, Tuple>;
    using ELayout     = std::tuple_element_t<8, Tuple>;

    constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl<ADataType,
                                                                                   BDataType,
                                                                                   AccDataType,
                                                                                   D0DataType,
                                                                                   EDataType,
                                                                                   ALayout,
                                                                                   BLayout,
                                                                                   D0Layout,
                                                                                   ELayout>;

    // Hook for derived fixtures: returns the profiler entry point Run() calls.
    virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; }

    // Runs the profiler for every {M, N, K} below and expects all to pass.
    void Run()
    {
        // {M, N, K}: one small, one large, and one skinny-K problem.
        std::vector<std::vector<ck::index_t>> lengths = {
            {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};

        bool all_success = true;

        for(const auto& length : lengths) // const&: no per-iteration vector copy
        {
            const int M = length[0];
            const int N = length[1];
            const int K = length[2];

            // Packed strides: leading dimension equals the contiguous extent.
            const int StrideA  = ck::is_same_v<ALayout, Row> ? K : M;
            const int StrideB  = ck::is_same_v<BLayout, Row> ? N : K;
            const int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
            const int StrideE  = ck::is_same_v<ELayout, Row> ? N : M;

            // Bitwise '&' (not '&&') is deliberate: every problem size runs
            // even after an earlier failure, so one EXPECT covers them all.
            all_success =
                all_success &
                GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
        }

        EXPECT_TRUE(all_success);
    }
};

using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
                                     std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;

TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);

TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
#include "test_gemm_add.hpp"

// gemm + add + fastgelu suite: reuses TestGemmAdd's Run() driver and swaps in
// the fastgelu profiler via the GetImpl() hook.
template <typename Tuple>
class TestGemmAddFastgelu : public TestGemmAdd<Tuple>
{
    private:
    // Tuple layout: <AData, BData, AccData, D0Data, EData, ALayout, BLayout, D0Layout, ELayout>.
    using ADataType   = std::tuple_element_t<0, Tuple>;
    using BDataType   = std::tuple_element_t<1, Tuple>;
    using AccDataType = std::tuple_element_t<2, Tuple>;
    using D0DataType  = std::tuple_element_t<3, Tuple>;
    using EDataType   = std::tuple_element_t<4, Tuple>;
    using ALayout     = std::tuple_element_t<5, Tuple>;
    using BLayout     = std::tuple_element_t<6, Tuple>;
    using D0Layout    = std::tuple_element_t<7, Tuple>;
    using ELayout     = std::tuple_element_t<8, Tuple>;

    constexpr static auto ProfileGemmAddFastgeluImpl =
        ck::profiler::profile_gemm_add_fastgelu_impl<ADataType,
                                                     BDataType,
                                                     AccDataType,
                                                     D0DataType,
                                                     EDataType,
                                                     ALayout,
                                                     BLayout,
                                                     D0Layout,
                                                     ELayout>;

    decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; }
};

// Identical re-declaration of KernelTypes from test_gemm_add.hpp (legal: a
// using-alias may be redeclared as the same type in the same scope).
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
                                     std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;

TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);

// Renamed Test_BF16FP16 -> Test_BF16FP16_INT8 for consistency with the
// sibling add/add_relu/add_silu suites (B is int8 here as well).
TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16_INT8) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_relu_impl.hpp"
#include "test_gemm_add.hpp"
// gemm + add + relu suite: reuses TestGemmAdd's Run() driver and only swaps
// the profiler implementation returned by the GetImpl() hook.
template <typename Tuple>
class TestGemmAddRelu : public TestGemmAdd<Tuple>
{
    private:
    // Tuple layout: <AData, BData, AccData, D0Data, EData, ALayout, BLayout, D0Layout, ELayout>.
    using ADataType = std::tuple_element_t<0, Tuple>;
    using BDataType = std::tuple_element_t<1, Tuple>;
    using AccDataType = std::tuple_element_t<2, Tuple>;
    using D0DataType = std::tuple_element_t<3, Tuple>;
    using EDataType = std::tuple_element_t<4, Tuple>;
    using ALayout = std::tuple_element_t<5, Tuple>;
    using BLayout = std::tuple_element_t<6, Tuple>;
    using D0Layout = std::tuple_element_t<7, Tuple>;
    using ELayout = std::tuple_element_t<8, Tuple>;
    constexpr static auto ProfileGemmAddReluImpl =
        ck::profiler::profile_gemm_add_relu_impl<ADataType,
                                                 BDataType,
                                                 AccDataType,
                                                 D0DataType,
                                                 EDataType,
                                                 ALayout,
                                                 BLayout,
                                                 D0Layout,
                                                 ELayout>;
    // Override the base-class hook so Run() profiles the fused relu kernel.
    decltype(ProfileGemmAddReluImpl) GetImpl() override { return ProfileGemmAddReluImpl; }
};
// Same KernelTypes as test_gemm_add.hpp; identical alias redeclaration is legal.
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
                                     std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_INT8) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_silu_impl.hpp"
#include "test_gemm_add.hpp"
// gemm + add + silu suite: reuses TestGemmAdd's Run() driver and only swaps
// the profiler implementation returned by the GetImpl() hook.
template <typename Tuple>
class TestGemmAddSilu : public TestGemmAdd<Tuple>
{
    private:
    // Tuple layout: <AData, BData, AccData, D0Data, EData, ALayout, BLayout, D0Layout, ELayout>.
    using ADataType = std::tuple_element_t<0, Tuple>;
    using BDataType = std::tuple_element_t<1, Tuple>;
    using AccDataType = std::tuple_element_t<2, Tuple>;
    using D0DataType = std::tuple_element_t<3, Tuple>;
    using EDataType = std::tuple_element_t<4, Tuple>;
    using ALayout = std::tuple_element_t<5, Tuple>;
    using BLayout = std::tuple_element_t<6, Tuple>;
    using D0Layout = std::tuple_element_t<7, Tuple>;
    using ELayout = std::tuple_element_t<8, Tuple>;
    constexpr static auto ProfileGemmAddSiluImpl =
        ck::profiler::profile_gemm_add_silu_impl<ADataType,
                                                 BDataType,
                                                 AccDataType,
                                                 D0DataType,
                                                 EDataType,
                                                 ALayout,
                                                 BLayout,
                                                 D0Layout,
                                                 ELayout>;
    // Override the base-class hook so Run() profiles the fused silu kernel.
    decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; }
};
// Same KernelTypes as test_gemm_add.hpp; identical alias redeclaration is legal.
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
                                     std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); }
...@@ -60,7 +60,9 @@ class TestGemmSplitK : public testing::Test ...@@ -60,7 +60,9 @@ class TestGemmSplitK : public testing::Test
const int StrideA, const int StrideA,
const int StrideB, const int StrideB,
const int StrideC, const int StrideC,
int kbatch = 1) int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{ {
bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType, bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
BDataType, BDataType,
...@@ -68,8 +70,19 @@ class TestGemmSplitK : public testing::Test ...@@ -68,8 +70,19 @@ class TestGemmSplitK : public testing::Test
CDataType, CDataType,
ALayout, ALayout,
BLayout, BLayout,
CLayout>( CLayout>(verify_,
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch); init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
}; };
......
...@@ -55,10 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -55,10 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
} }
} }
const bool is_navi3x = ck::get_device_name() == "gfx1100" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102";
if(is_navi3x)
{ {
// on navi3x only support for 3d is implemented // on navi3x only support for 3d is implemented
if constexpr(NDimSpatial{} != 3) if constexpr(NDimSpatial{} != 3)
......
...@@ -63,7 +63,9 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -63,7 +63,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs,
int kbatch = 1) int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{ {
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType, BDataType,
...@@ -71,8 +73,19 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -71,8 +73,19 @@ class TestGroupedGemm : public testing::TestWithParam<int>
float, float,
ALayout, ALayout,
BLayout, BLayout,
ELayout>( ELayout>(verify_,
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch); init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
}; };
......
# Aggregate target grouping all normalization backward-data gtests.
add_custom_target(test_normalization_bwd_data)
add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp)
# add_gtest_executable sets `result` to 0 in the parent scope only when the
# target was actually created; guard linking/grouping on that.
if(result EQUAL 0)
    target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
    add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32)
endif()
add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp)
if(result EQUAL 0)
    target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
    add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_groupnorm_bwd_data_impl.hpp"

using F16 = ck::half_t;
using F32 = float;

using ck::index_t;

// Typed fixture for groupnorm backward-data: runs the profiler with
// verification enabled over a set of [N, H, W, G, C] input shapes.
// NOTE(review): class name uses a lowercase 'g' unlike the sibling
// TestLayernorm2dBwdData -- kept as-is so the gtest suite name is unchanged.
template <typename Tuple>
class TestgroupnormBwdData : public ::testing::Test
{
    protected:
    using DYDataType         = std::tuple_element_t<0, Tuple>;
    using XDataType          = std::tuple_element_t<1, Tuple>;
    using GammaDataType      = std::tuple_element_t<2, Tuple>;
    using MeanInvStdDataType = std::tuple_element_t<3, Tuple>;
    using ComputeDataType    = std::tuple_element_t<4, Tuple>;
    using DXDataType         = std::tuple_element_t<5, Tuple>;

    void Run()
    {
        // Input layout [N, H, W, G, C]; bwd-data reduces over H, W, C.
        // Shapes cover the degenerate all-ones case, odd sizes, and larger
        // realistic problems.
        std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
                                                         {1, 2, 3, 4, 5},
                                                         {256, 9, 9, 9, 9},
                                                         {1, 64, 64, 32, 10},
                                                         {1, 32, 32, 32, 20},
                                                         {1, 16, 16, 32, 40}};

        for(const auto& length : lengths) // const&: avoid copying each shape vector
        {
            // Flags: (true, 2, false, false) -- verification on, init method 2,
            // logging/benchmark off (argument order per the profiler impls).
            bool success = ck::profiler::profile_groupnorm_bwd_data_impl<DYDataType,
                                                                         XDataType,
                                                                         GammaDataType,
                                                                         MeanInvStdDataType,
                                                                         ComputeDataType,
                                                                         DXDataType>(
                true, 2, false, false, length);
            EXPECT_TRUE(success);
        }
    }
};

using KernelTypes = ::testing::Types<
    // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType>
    std::tuple<F32, F32, F32, F32, F32, F32>>;

TYPED_TEST_SUITE(TestgroupnormBwdData, KernelTypes);

TYPED_TEST(TestgroupnormBwdData, Test_FP32) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_layernorm_bwd_data_impl.hpp"

using F16 = ck::half_t;
using F32 = float;

using ck::index_t;

// Typed fixture for 2-D layernorm backward-data: runs the profiler with
// verification enabled over a set of [N, D] input shapes.
template <typename Tuple>
class TestLayernorm2dBwdData : public ::testing::Test
{
    protected:
    using DYDataType         = std::tuple_element_t<0, Tuple>;
    using XDataType          = std::tuple_element_t<1, Tuple>;
    using GammaDataType      = std::tuple_element_t<2, Tuple>;
    using MeanInvStdDataType = std::tuple_element_t<3, Tuple>;
    using ComputeDataType    = std::tuple_element_t<4, Tuple>;
    using DXDataType         = std::tuple_element_t<5, Tuple>;

    void Run()
    {
        // Input layout [N, D]; bwd-data reduces over D. Shapes cover small,
        // non-power-of-two, and large problems.
        std::vector<std::vector<ck::index_t>> lengths = {
            {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};

        for(const auto& length : lengths) // const&: avoid copying each shape vector
        {
            // Flags: (true, 2, false, false) -- verification on, init method 2,
            // logging/benchmark off; trailing 2 is the tensor rank.
            bool success =
                ck::profiler::profile_layernorm_bwd_data_impl<DYDataType,
                                                              XDataType,
                                                              GammaDataType,
                                                              MeanInvStdDataType,
                                                              ComputeDataType,
                                                              DXDataType,
                                                              2>(true, 2, false, false, length);
            EXPECT_TRUE(success);
        }
    }
};

using KernelTypes = ::testing::Types<
    // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType>
    std::tuple<F32, F32, F32, F32, F32, F32>>;

TYPED_TEST_SUITE(TestLayernorm2dBwdData, KernelTypes);

TYPED_TEST(TestLayernorm2dBwdData, Test_FP32) { this->Run(); }
# Aggregate target grouping all normalization backward-gamma/beta gtests.
add_custom_target(test_normalization_bwd_gamma_beta)
add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp)
# add_gtest_executable sets `result` to 0 in the parent scope only when the
# target was actually created; guard linking/grouping on that.
if(result EQUAL 0)
    target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
    add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32)
endif()
add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp)
if(result EQUAL 0)
    target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
    add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp"

using F16 = ck::half_t;
using F32 = float;

using ck::index_t;

// Typed fixture for groupnorm backward-gamma/beta: runs the profiler with
// verification enabled over a set of [N, H, W, G, C] input shapes.
template <typename Tuple>
class TestgroupnormBwdGammaBeta : public ::testing::Test
{
    protected:
    using DYDataType         = std::tuple_element_t<0, Tuple>;
    using XDataType          = std::tuple_element_t<1, Tuple>;
    using MeanInvStdDataType = std::tuple_element_t<2, Tuple>;
    using ComputeDataType    = std::tuple_element_t<3, Tuple>;
    using DGammaDataType     = std::tuple_element_t<4, Tuple>;
    using DBetaDataType      = std::tuple_element_t<5, Tuple>;

    void Run()
    {
        // Bwd gamma/beta: input layout [N, H, W, G, C]. (Original comment said
        // "Bwd data ... reduce H, W, C" -- copied from the bwd-data test;
        // dgamma/dbeta reduction axes are presumably the non-(G, C) dims.
        // TODO(review): confirm against the profiler impl.)
        std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
                                                         {1, 2, 3, 4, 5},
                                                         {256, 9, 9, 9, 9},
                                                         {1, 64, 64, 32, 10},
                                                         {1, 32, 32, 32, 20},
                                                         {1, 16, 16, 32, 40}};

        for(const auto& length : lengths) // const&: avoid copying each shape vector
        {
            // Flags: (true, 2, false, false) -- verification on, init method 2,
            // logging/benchmark off (argument order per the profiler impls).
            bool success = ck::profiler::profile_groupnorm_bwd_gamma_beta_impl<DYDataType,
                                                                               XDataType,
                                                                               MeanInvStdDataType,
                                                                               ComputeDataType,
                                                                               DGammaDataType,
                                                                               DBetaDataType>(
                true, 2, false, false, length);
            EXPECT_TRUE(success);
        }
    }
};

using KernelTypes = ::testing::Types<
    // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType>
    std::tuple<F32, F32, F32, F32, F32, F32>>;

TYPED_TEST_SUITE(TestgroupnormBwdGammaBeta, KernelTypes);

TYPED_TEST(TestgroupnormBwdGammaBeta, Test_FP32) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp"

using F16 = ck::half_t;
using F32 = float;

using ck::index_t;

// Typed fixture for 2-D layernorm backward-gamma/beta: runs the profiler with
// verification enabled over a set of [N, D] input shapes.
template <typename Tuple>
class TestLayernorm2dBwdGammaBeta : public ::testing::Test
{
    protected:
    using DYDataType         = std::tuple_element_t<0, Tuple>;
    using XDataType          = std::tuple_element_t<1, Tuple>;
    using MeanInvStdDataType = std::tuple_element_t<2, Tuple>;
    using ComputeDataType    = std::tuple_element_t<3, Tuple>;
    using DGammaDataType     = std::tuple_element_t<4, Tuple>;
    using DBetaDataType      = std::tuple_element_t<5, Tuple>;

    void Run()
    {
        // Bwd gamma/beta: input layout [N, D]. (Original comment said
        // "Bwd data ... reduce D" -- copied from the bwd-data test; dgamma/
        // dbeta are presumably reduced over N. TODO(review): confirm against
        // the profiler impl.)
        std::vector<std::vector<ck::index_t>> lengths = {
            {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};

        for(const auto& length : lengths) // const&: avoid copying each shape vector
        {
            // Flags: (true, 2, false, false) -- verification on, init method 2,
            // logging/benchmark off; trailing 2 is the tensor rank.
            bool success = ck::profiler::profile_layernorm_bwd_gamma_beta_impl<DYDataType,
                                                                               XDataType,
                                                                               MeanInvStdDataType,
                                                                               ComputeDataType,
                                                                               DGammaDataType,
                                                                               DBetaDataType,
                                                                               2>(
                true, 2, false, false, length);
            EXPECT_TRUE(success);
        }
    }
};

using KernelTypes = ::testing::Types<
    // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType>
    std::tuple<F32, F32, F32, F32, F32, F32>>;

TYPED_TEST_SUITE(TestLayernorm2dBwdGammaBeta, KernelTypes);

TYPED_TEST(TestLayernorm2dBwdGammaBeta, Test_FP32) { this->Run(); }
...@@ -47,8 +47,8 @@ class TestGroupnorm : public ::testing::Test ...@@ -47,8 +47,8 @@ class TestGroupnorm : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F16, F16, F16, F32, F16, F32>>; std::tuple<F16, F16, F16, F32, F16, F16>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); } TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); }
...@@ -45,7 +45,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -45,7 +45,7 @@ class TestGroupnorm : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F32, F32, F32, F32, F32, F32>>; std::tuple<F32, F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
......
...@@ -41,8 +41,8 @@ class TestLayernorm2d : public ::testing::Test ...@@ -41,8 +41,8 @@ class TestLayernorm2d : public ::testing::Test
}; };
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
std::tuple<F16, F16, F16, F32, F16, F32>>; std::tuple<F16, F16, F16, F32, F16, F16>>;
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); } TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment