Unverified Commit 5aa3c344 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 7fefc966 9d8f834a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
{
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using B1DataType = std::tuple_element_t<2, Tuple>;
using CDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using B0Layout = std::tuple_element_t<5, Tuple>;
using B1Layout = std::tuple_element_t<6, Tuple>;
using CPermuteNumDims_G_M_O = std::tuple_element_t<7, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 6, 4},
{256, 256, 128, 128, 4, 6},
{512, 512, 64, 64, 3, 2},
{512, 512, 128, 128, 2, 3},
{1024, 1024, 64, 64, 3, 1},
{1024, 1024, 128, 128, 1, 1},
};
bool bench_ = false;
bool verify_ = true;
void RunSingle(int M, int N, int K, int O, int G0, int G1)
{
bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
ADataType,
B0DataType,
B1DataType,
CDataType,
ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
EXPECT_TRUE(pass);
}
void Run()
{
for(auto lengths : this->lengths_)
{
int M = lengths[0];
int N = lengths[1];
int K = lengths[2];
int O = lengths[3];
int G0 = lengths[4];
int G1 = lengths[5];
this->RunSingle(M, N, K, O, G0, G1);
}
}
};
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using CPermuteNumDims_G_M_O =
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = float;
using CShuffleDataType = F16;
using CDataType = F16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using DeviceGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // Masking
bool IsSupported(int M, int N, int K, int O)
{
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
M,
N,
K,
O,
0, // BatchCount
{0, 0, M, O}, // gs ms ns lengths
{0, O, 0, 1}, // gs ms ns strides
0, // StrideA
0, // StrideB0
0, // StrideB1
0, // BatchStrideA
0, // BatchStrideB0
0, // BatchStrideB1
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
Scale{1.f}, // acc0_element_op
PassThrough{}, // b1_element_op
PassThrough{}); // c_element_op
return gemm.IsSupportedArgument(argument);
}
};
......@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 16},
{256, 64, 160, 64, 16},
{1024, 1024, 80, 80, 16},
{1024, 64, 80, 64, 16},
{4096, 4096, 40, 40, 16},
{4096, 64, 40, 64, 16}};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
// TODO: enable KPadding tests when it is implemented
......@@ -118,19 +131,19 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
......
......@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
};
std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
{256, 256, 160, 160, 4},
{256, 64, 160, 64, 4},
{1024, 1024, 80, 80, 2},
{1024, 64, 80, 64, 2},
{4096, 4096, 40, 40, 1},
{4096, 64, 40, 64, 1}};
bool bench_ = false;
bool verify_ = true;
......@@ -155,7 +160,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
false>;
bool IsSupported(int M, int N, int K, int O)
{
......
add_custom_target(test_layernorm)
add_gtest_executable(test_layernorm_fp32 test_layernorm_fp32.cpp)
add_gtest_executable(test_layernorm_fp16 test_layernorm_fp16.cpp)
add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp)
add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp)
add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp)
add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
target_link_libraries(test_layernorm_fp32 PRIVATE utility)
target_link_libraries(test_layernorm_fp16 PRIVATE utility)
target_link_libraries(test_layernorm2d_fp32 PRIVATE utility)
target_link_libraries(test_layernorm2d_fp16 PRIVATE utility)
target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance)
target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance)
add_dependencies(test_layernorm test_layernorm2d_fp32)
add_dependencies(test_layernorm test_layernorm2d_fp16)
add_dependencies(test_layernorm test_groupnorm_fp16)
add_dependencies(test_layernorm test_groupnorm_fp32)
add_dependencies(test_layernorm test_layernorm_fp32)
add_dependencies(test_layernorm test_layernorm_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/include/profile_groupnorm_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
template <typename Tuple>
class TestGroupnorm : public ::testing::Test
{
protected:
using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>;
void Run()
{
// N, H, W, G, C
std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
{1, 2, 3, 4, 5},
{256, 9, 9, 9, 9},
{1, 64, 64, 32, 10},
{1, 32, 32, 32, 20},
{1, 16, 16, 32, 40}};
for(auto length : lengths)
{
bool success =
ck::profiler::profile_groupnorm_impl<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType>(true, 2, false, false, length);
EXPECT_TRUE(success);
}
}
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>,
std::tuple<F16, F16, F16, F32, F16>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/include/profile_groupnorm_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
template <typename Tuple>
class TestGroupnorm : public ::testing::Test
{
protected:
using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>;
using AccDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>;
void Run()
{
// N, H, W, G, C
std::vector<std::vector<ck::index_t>> lengths = {{1, 1, 1, 1, 1},
{1, 2, 3, 4, 5},
{256, 9, 9, 9, 9},
{1, 64, 64, 32, 10},
{1, 32, 32, 32, 20},
{1, 16, 16, 32, 40}};
for(auto length : lengths)
{
bool success =
ck::profiler::profile_groupnorm_impl<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType>(true, 2, false, false, length);
EXPECT_TRUE(success);
}
}
};
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>,
std::tuple<F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
TYPED_TEST(TestGroupnorm, Test_FP32) { this->Run(); }
......@@ -2,28 +2,28 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_layernorm_util.hpp"
#include "test_layernorm2d_util.hpp"
template <ck::index_t N>
using I = ck::Number<N>;
template <typename Tuple>
class TestLayernormFP16 : public ck::TestLayernorm<Tuple>
class TestLayernorm2dFP16 : public ck::TestLayernorm2d<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, , GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim , GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>
>;
// clang-format on
TYPED_TEST_SUITE(TestLayernormFP16, KernelTypes);
TYPED_TEST(TestLayernormFP16, Test_FP16) { this->Run(); }
TYPED_TEST_SUITE(TestLayernorm2dFP16, KernelTypes);
TYPED_TEST(TestLayernorm2dFP16, Test_FP16) { this->Run(); }
......@@ -2,28 +2,28 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_layernorm_util.hpp"
#include "test_layernorm2d_util.hpp"
template <ck::index_t N>
using I = ck::Number<N>;
template <typename Tuple>
class TestLayernormFP32 : public ck::TestLayernorm<Tuple>
class TestLayernorm2dFP32 : public ck::TestLayernorm2d<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, , GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>
>;
// clang-format on
TYPED_TEST_SUITE(TestLayernormFP32, KernelTypes);
TYPED_TEST(TestLayernormFP32, Test_FP32) { this->Run(); }
TYPED_TEST_SUITE(TestLayernorm2dFP32, KernelTypes);
TYPED_TEST(TestLayernorm2dFP32, Test_FP32) { this->Run(); }
......@@ -31,7 +31,7 @@ std::string serialize_range(const Range& range)
}
template <typename Tuple>
class TestLayernorm : public ::testing::Test
class TestLayernorm2d : public ::testing::Test
{
protected:
using XDataType = std::tuple_element_t<0, Tuple>;
......@@ -48,9 +48,11 @@ class TestLayernorm : public ::testing::Test
static constexpr index_t KThreadSliceSize = std::tuple_element_t<11, Tuple>{}.value;
static constexpr index_t XYSrcVectorDim = std::tuple_element_t<12, Tuple>{}.value;
static constexpr index_t XSrcVectorSize = std::tuple_element_t<13, Tuple>{}.value;
static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<14, Tuple>{}.value;
static constexpr index_t BetaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
static constexpr index_t YDstVectorSize = std::tuple_element_t<16, Tuple>{}.value;
static constexpr index_t GammaSrcVectorDim = std::tuple_element_t<14, Tuple>{}.value;
static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
static constexpr index_t BetaSrcVectorDim = std::tuple_element_t<16, Tuple>{}.value;
static constexpr index_t BetaSrcVectorSize = std::tuple_element_t<17, Tuple>{}.value;
static constexpr index_t YDstVectorSize = std::tuple_element_t<18, Tuple>{}.value;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......@@ -78,23 +80,24 @@ class TestLayernorm : public ::testing::Test
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorSize>;
TestLayernorm() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}
TestLayernorm2d() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}
void RunSingle(std::vector<index_t> lengths, std::vector<index_t> reduceDims)
void RunSingle(const std::vector<index_t>& lengths,
const std::vector<index_t>& reduceDims,
const std::vector<index_t>& GammaLength,
const std::vector<index_t>& GammaStride,
const std::vector<index_t>& BetaLength,
const std::vector<index_t>& BetaStride)
{
std::vector<index_t> reduceLength(reduceDims.size());
for(int i = 0; i < NumReduceDim; ++i)
{
reduceLength[i] = lengths[reduceDims[i]];
}
Tensor<XDataType> x(lengths);
Tensor<GammaDataType> gamma(reduceLength);
Tensor<BetaDataType> beta(reduceLength);
Tensor<GammaDataType> gamma(GammaLength);
Tensor<BetaDataType> beta(BetaLength);
Tensor<YDataType> y(lengths);
Tensor<YDataType> y_ref(lengths);
......@@ -115,10 +118,8 @@ class TestLayernorm : public ::testing::Test
auto argument_ptr = device_instance.MakeArgumentPointer(
lengths,
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(),
gamma.mDesc.GetStrides().end()},
std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(),
beta.mDesc.GetStrides().end()},
GammaStride,
BetaStride,
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
reduceDims,
1e-4,
......@@ -163,17 +164,16 @@ class TestLayernorm : public ::testing::Test
void Run()
{
for(auto length : this->lengths_)
std::vector<std::vector<index_t>> lengths = {
{4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};
for(auto length : lengths)
{
this->RunSingle(length, reduceDims_[0]);
this->RunSingle(length, {1}, {length[1]}, {0, 1}, {length[1]}, {0, 1});
}
}
std::vector<std::vector<index_t>> lengths_ = {
{4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};
std::vector<std::vector<index_t>> reduceDims_ = {{1}};
typename ReferenceInstance::Invoker ref_instance_invoker_;
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment