Commit 200dd06b authored by wangshaojie6

add test file

parent 6e0a93d2
-add_custom_target(test_batched_gemm_softmax_gemm)
+add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)
 add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
-target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
-add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
+target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_masking_softmax_gemm_permute_instance)
+add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_softmax_gemm_fp16)
\ No newline at end of file
@@ -5,8 +5,8 @@
 #include <vector>

 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
-#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
+#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"

 using ck::tensor_operation::device::GemmSpecialization;

 template <ck::index_t N>
@@ -20,37 +20,37 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 template <typename Tuple>
 struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
 {
     using ADataType = std::tuple_element_t<0, Tuple>;
     using B0DataType = std::tuple_element_t<1, Tuple>;
     using B1DataType = std::tuple_element_t<2, Tuple>;
     using CDataType = std::tuple_element_t<3, Tuple>;
     using ALayout = std::tuple_element_t<4, Tuple>;
     using B0Layout = std::tuple_element_t<5, Tuple>;
     using B1Layout = std::tuple_element_t<6, Tuple>;
-    using CLayout = std::tuple_element_t<7, Tuple>;
+    using CPermuteNumDims_G_M_O = std::tuple_element_t<7, Tuple>;

     std::vector<std::vector<int>> lengths_ = {
-        {256, 256, 64, 64, 4},
-        {256, 256, 128, 128, 4},
-        {512, 512, 64, 64, 2},
-        {512, 512, 128, 128, 2},
-        {1024, 1024, 64, 64, 1},
-        {1024, 1024, 128, 128, 1},
+        {256, 256, 64, 64, 6, 4},
+        {256, 256, 128, 128, 4, 6},
+        {512, 512, 64, 64, 3, 2},
+        {512, 512, 128, 128, 2, 3},
+        {1024, 1024, 64, 64, 3, 1},
+        {1024, 1024, 128, 128, 1, 1},
     };

     bool bench_ = false;
     bool verify_ = true;

-    void RunSingle(int M, int N, int K, int O, int BatchCount)
+    void RunSingle(int M, int N, int K, int O, int G0, int G1)
     {
-        bool pass = ck::profiler::profile_batched_gemm_softmax_gemm_impl<ADataType,
+        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<ADataType,
             B0DataType,
             B1DataType,
             CDataType,
             ALayout,
             B0Layout,
             B1Layout,
-            CLayout>(
-            verify_, 1, false, bench_, M, N, K, O, BatchCount);
+            CPermuteNumDims_G_M_O>(
+            verify_, 1, false, bench_, M, N, K, O, G0, G1);
         EXPECT_TRUE(pass);
     }
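For context, the renamed profiler entry point exercises a masking + scale + softmax variant of the fused GEMM-GEMM. Below is a naive single-batch reference of what such a kernel plausibly computes, assuming a causal mask and the TNTT layouts used by the wrapper further down (A is M x K row-major, B0 is K x N column-major, B1 is N x O row-major); it is an illustrative sketch, not the profiler's actual verification code.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Reference: C = softmax(scale * A * B0, causal mask) * B1, one batch slice.
void reference_masking_scale_softmax_gemm(const std::vector<float>& a,
                                          const std::vector<float>& b0,
                                          const std::vector<float>& b1,
                                          std::vector<float>& c,
                                          int M, int N, int K, int O, float scale)
{
    std::vector<float> s(static_cast<size_t>(M) * N);

    // Gemm0 with scale, then causal mask: entries with n > m become -inf.
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * b0[n * K + k]; // col-major B0: (k, n) at n * K + k
            s[m * N + n] = n > m ? -std::numeric_limits<float>::infinity() : scale * acc;
        }

    // Row-wise numerically-stable softmax.
    for(int m = 0; m < M; ++m)
    {
        float mx = -std::numeric_limits<float>::infinity();
        for(int n = 0; n < N; ++n)
            mx = std::max(mx, s[m * N + n]);
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
            sum += (s[m * N + n] = std::exp(s[m * N + n] - mx));
        for(int n = 0; n < N; ++n)
            s[m * N + n] /= sum;
    }

    // Gemm1: C = softmax(S) * B1.
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += s[m * N + n] * b1[n * O + o];
            c[m * O + o] = acc;
        }
}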
@@ -63,9 +63,10 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
             int N = lengths[1];
             int K = lengths[2];
             int O = lengths[3];
-            int BatchCount = lengths[4];
+            int G0 = lengths[4];
+            int G1 = lengths[5];

-            this->RunSingle(M, N, K, O, BatchCount);
+            this->RunSingle(M, N, K, O, G0, G1);
         }
     }
 };
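Each row of lengths_ now ends with two batch dimensions instead of a single BatchCount. In an attention-style workload these plausibly correspond to batch size and head count, with the effective GEMM batch being their product; that interpretation is an assumption, not something the diff states. A trivial illustration:

// Hypothetical reading of a test case row, e.g. {256, 256, 64, 64, 6, 4}:
// M = N = 256 (sequence lengths), K = 64 and O = 64 (per-head dims),
// G0 = 6 and G1 = 4, giving an effective GEMM batch of G0 * G1 = 24.
struct ProblemSize
{
    int M, N, K, O, G0, G1;
    int FlatBatch() const { return G0 * G1; }
};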
@@ -74,36 +75,38 @@ template <GemmSpecialization GemmSpec>
 struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
 {
     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+    using Scale = ck::tensor_operation::element_wise::Scale;

     using ALayout = Row;
     using B0Layout = Col;
     using B1Layout = Row;
-    using CLayout = Row;
+
+    template <ck::index_t... Is>
+    using S = ck::Sequence<Is...>;
+    using CPermuteNumDims_G_M_O =
+        S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O

     using ADataType = F16;
     using B0DataType = F16;
     using B1DataType = F16;
     using AccDataType = float;
-    using CShuffleDataType = float;
+    using CShuffleDataType = F16;
     using CDataType = F16;

     using AElementOp = PassThrough;
     using B0ElementOp = PassThrough;
-    using Acc0ElementOp = PassThrough;
+    using Acc0ElementOp = Scale;
     using B1ElementOp = PassThrough;
     using CElementOp = PassThrough;

-    template <ck::index_t... Is>
-    using S = ck::Sequence<Is...>;
-
     // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;

     using DeviceGemmGemmInstance =
-        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
+        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
             ALayout,
             B0Layout,
             B1Layout,
-            CLayout,
+            CPermuteNumDims_G_M_O,
             ADataType,
             B0DataType,
             B1DataType,
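As the inline comment in the hunk above notes, the fixed row-major CLayout is replaced by CPermuteNumDims_G_M_O = S<2, 1, 1>: the output is no longer a plain matrix but a permutable tensor with 2 G (batch) dimensions, 1 M dimension and 1 O dimension, addressed through caller-supplied lengths and strides. The sketch below shows the kind of indexing this enables; the [g0][m][g1][o] memory order is just one common attention layout, assumed here for illustration and not pinned down by the diff.

#include <array>
#include <cstdint>

// Flat offset into a rank-4 output C[g0][g1][m][o] whose memory order is
// [g0][m][g1][o] (heads interleaved within each row block).
std::int64_t c_offset(std::array<int, 4> idx,     // {g0, g1, m, o}
                      std::array<int, 4> lengths) // {G0, G1, M, O}
{
    const int G1 = lengths[1], M = lengths[2], O = lengths[3];
    const std::array<std::int64_t, 4> strides = {
        static_cast<std::int64_t>(M) * G1 * O, // g0
        O,                                     // g1
        static_cast<std::int64_t>(G1) * O,     // m
        1};                                    // o
    std::int64_t off = 0;
    for(int d = 0; d < 4; ++d)
        off += idx[d] * strides[d];
    return off;
}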
@@ -155,7 +158,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             1,              // CShuffleMXdlPerWavePerShuffle
             2,              // CShuffleNXdlPerWavePerShuffle
             S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-            8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock
+            8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+            true>;          // Masking

     bool IsSupported(int M, int N, int K, int O)
     {
@@ -170,6 +174,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             K,
             O,
             0,            // BatchCount
+            {0, 0, M, O}, // gs ms ns lengths
+            {0, O, 0, 1}, // gs ms ns strides
             0,            // StrideA
             0,            // StrideB0
             0,            // StrideB1
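The two new arguments give the permuted output its shape and layout. Inside IsSupported the G dimensions are zero-sized placeholders and {0, O, 0, 1} simply describes a row-major M x O slice, since only per-tile divisibility is being checked. In a real launch they would presumably carry the full rank-4 description, along the lines of the hypothetical helper below (the names and the [g0][m][g1][o] memory order are assumptions, not part of the diff):

#include <utility>
#include <vector>

// Hypothetical: build the rank-4 output lengths/strides for a real launch,
// assuming a [g0][m][g1][o] memory order for the [G0, G1, M, O] output.
std::pair<std::vector<int>, std::vector<int>>
make_c_desc(int G0, int G1, int M, int O)
{
    std::vector<int> lengths{G0, G1, M, O};             // gs, ms, os lengths
    std::vector<int> strides{M * G1 * O, O, G1 * O, 1}; // matching strides
    return {lengths, strides};
}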
@@ -180,7 +186,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             0,              // BatchStrideC
             PassThrough{},  // a_element_op
             PassThrough{},  // b0_element_op
-            PassThrough{},  // acc0_element_op
+            Scale{},        // acc0_element_op
             PassThrough{},  // b1_element_op
             PassThrough{}); // c_element_op
...