Commit 97dcc7b2 authored by wangshaojie6

add gtest for bmm masking scale softmax bmm permute

parent 200dd06b
@@ -50,7 +50,7 @@ struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
         B0Layout,
         B1Layout,
-        CLayout,
+        CPermuteNumDims_G_M_Gemm1N,
         ADataType,
         B0DataType,
         B1DataType,
@@ -83,7 +83,7 @@ struct DeviceOperationInstanceFactory<
        is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
     {
         if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
-                     is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
+                     is_same_v<B1Layout, Row> && is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
         {
             add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                 op_ptrs);
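
The `CPermuteNumDims_G_M_Gemm1N` slot replaces a plain output layout with a count of how many tensor dimensions belong to each of the G (batch), M, and Gemm1N (O) groups. A minimal sketch of the sequence used by the new FP16 instance, assuming `ck::Sequence` is reachable from CK's utility headers (the include path below is an assumption):

    #include "ck/utility/sequence.hpp" // assumed include path for ck::Sequence

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    // Two G (batch) dims, one M dim, one Gemm1N (O) dim: the permuted
    // output c_gs_ms_os is a rank-4 tensor indexed as (g0, g1, m, o).
    using CPermuteNumDims_G_M_O = S<2, 1, 1>;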
......
-add_instance_library(device_batched_gemm_masking_softmax_gemm_permute_instance
-    device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+add_instance_library(device_batched_gemm_masking_scale_softmax_gemm_permute_instance
+    device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
 )
@@ -43,11 +43,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     int StrideA = -1,
     int StrideB0 = -1,
     int StrideB1 = -1,
-    int StrideC = -1,
     int BatchStrideA = -1,
     int BatchStrideB0 = -1,
     int BatchStrideB1 = -1,
-    int BatchStrideC = -1,
     float alpha = 1.f)
 {
@@ -93,22 +91,18 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
     const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
     const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
-    const int DefaultStrideC  = ck::is_same_v<CLayout, Row> ? O : M;

     StrideA  = (StrideA < 0) ? DefaultStrideA : StrideA;
     StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
     StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
-    StrideC  = (StrideC < 0) ? DefaultStrideC : StrideC;

     const int DefaultBatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
     const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
     const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
-    const int DefaultBatchStrideC  = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;

     BatchStrideA  = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
     BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
     BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
-    BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;

     const int BatchCount = G0 * G1;
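
For reference, the packed-stride defaults above follow one simple rule. The sketch below restates it outside the template context (illustrative helper names, not library code):

    // Leading dimension of a packed rows x cols matrix: row-major has
    // stride cols, column-major has stride rows.
    int default_stride(bool row_major, int rows, int cols)
    {
        return row_major ? cols : rows;
    }

    // Distance between consecutive batches: rows * stride for row-major,
    // cols * stride for column-major; both equal rows * cols when packed.
    int default_batch_stride(bool row_major, int rows, int cols, int stride)
    {
        return (row_major ? rows : cols) * stride;
    }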
@@ -198,7 +192,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     auto a_element_op    = AElementOp{};
     auto b0_element_op   = B0ElementOp{};
-    auto acc0_element_op = Acc0ElementOp{};
+    auto acc0_element_op = Acc0ElementOp{alpha};
     auto b1_element_op   = B1ElementOp{};
     auto c_element_op    = CElementOp{};
@@ -227,7 +221,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     auto ref_gemm0          = ReferenceGemm0Instance{};
     auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
     auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-        a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
+        a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha});

     ref_gemm0_invoker.Run(ref_gemm0_argument);
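
With `Scale{alpha}` in place of `PassThrough{}`, the host reference applies the same pre-softmax scaling as the device kernel. A hedged sketch of what such an element-wise op does (the real `ck::tensor_operation::element_wise::Scale` may differ in member names and device qualifiers):

    struct ScaleSketch
    {
        explicit ScaleSketch(float scale) : scale_(scale) {}

        // y = scale_ * x, evaluated in float on every Gemm0 accumulator
        // value before the softmax stage consumes it.
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const
        {
            y = static_cast<Y>(scale_ * static_cast<float>(x));
        }

        float scale_;
    };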
@@ -272,20 +266,20 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
         static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
         static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
         static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
-        static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
+        static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
         M,
         N,
         K,
         O,
         BatchCount,
+        c_gs_ms_os_lengths,
+        c_gs_ms_os_strides,
         StrideA,
         StrideB0,
         StrideB1,
-        StrideC,
         BatchStrideA,
         BatchStrideB0,
         BatchStrideB1,
-        BatchStrideC,
         a_element_op,
         b0_element_op,
         acc0_element_op,
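
The `c_gs_ms_os_lengths`/`c_gs_ms_os_strides` pair replaces `StrideC`/`BatchStrideC` and describes the permuted output directly. One plausible setup, assuming the output is stored in memory as [G0, M, G1, O] while logically indexed as (g0, g1, m, o) — an attention-style permutation; the profiler's actual values may differ:

    #include <vector>

    // Illustrative sizes only.
    const int G0 = 2, G1 = 3, M = 128, O = 128;

    std::vector<int> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<int> c_gs_ms_os_strides{M * G1 * O, // g0 step
                                        O,          // g1 step
                                        G1 * O,     // m step
                                        1};         // o step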
@@ -323,10 +317,10 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     if(do_verification)
     {
-        c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
+        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

         pass = pass &
-               ck::utils::check_err(c_g_m_o_device_result.mData, c_gs_ms_os_host_result.mData);
+               ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);

         if(do_log)
         {
@@ -340,7 +334,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                 std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                 << std::endl;
             LogRangeAsType<float>(
-                std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",")
+                std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",")
                 << std::endl;
         }
     }
......
@@ -42,6 +42,7 @@ add_subdirectory(batched_gemm)
 add_subdirectory(batched_gemm_reduce)
 add_subdirectory(batched_gemm_gemm)
 add_subdirectory(batched_gemm_softmax_gemm)
+add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute)
 add_subdirectory(grouped_gemm)
 add_subdirectory(reduce)
 add_subdirectory(convnd_fwd)
......
 add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)

-add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
-target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_masking_softmax_gemm_permute_instance)
-add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_softmax_gemm_fp16)
\ No newline at end of file
+add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp)
+target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance)
+add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16)
\ No newline at end of file
@@ -2,103 +2,107 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 #include "gtest/gtest.h"
-#include "test_batched_gemm_softmax_gemm_util.hpp"
+#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"

 template <typename Tuple>
-class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
+class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
 {
 };

 // clang-format off
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+using CPermuteNumDims_G_M_O =
+    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
 using KernelTypes = ::testing::Types<
-    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
+    std::tuple<F16, F16, F16, F16, Row, Col, Row, CPermuteNumDims_G_M_O>
     >;
 // clang-format on

-TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);
+TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes);

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16) { this->Run(); }
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16) { this->Run(); }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadM)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {136, 128, 32, 128, 1},
+        {136, 128, 32, 128, 2, 3},
     };
     this->Run();
 }
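
Note that each test case now carries six sizes instead of five: the single trailing batch count is split into two batch dimensions. Inferred from the values in this diff, each row reads {M, N, K, O, G0, G1}; a sketch of unpacking one entry under that assumed convention:

    #include <vector>

    std::vector<int> lengths{136, 128, 32, 128, 2, 3};
    const int M  = lengths[0]; // Gemm0 rows
    const int N  = lengths[1]; // Gemm0 cols (= Gemm1 reduction dim)
    const int K  = lengths[2]; // Gemm0 reduction dim
    const int O  = lengths[3]; // Gemm1 cols
    const int G0 = lengths[4]; // outer batch dim
    const int G1 = lengths[5]; // inner batch dim
    const int BatchCount = G0 * G1; // 2 * 3 = 6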
-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadN)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 136, 32, 128, 1},
+        {128, 136, 32, 128, 3, 2},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadK)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 40, 128, 1},
-        {128, 128, 136, 128, 1},
+        {128, 128, 40, 128, 2, 4},
+        {128, 128, 136, 128, 4, 2},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadO)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 32, 136, 1},
+        {128, 128, 32, 136, 1, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddM)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {129, 128, 32, 128, 1},
+        {129, 128, 32, 128, 2, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddN)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 129, 32, 128, 1},
+        {128, 129, 32, 128, 4, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddK)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 33, 128, 1},
-        {128, 128, 129, 128, 1},
+        {128, 128, 33, 128, 2, 3},
+        {128, 128, 129, 128, 2, 3},
     };
     this->Run();
 }

 // If the kernel's B1Layout is RowMajor, odd O sizes are expected to be unsupported
-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddO)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 32, 129, 1},
+        {128, 128, 32, 129, 2, 3},
     };
     this->Run();
 }
-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {256, 256, 64, 64, 768},
-        {256, 256, 128, 128, 768},
-        {512, 512, 64, 64, 768},
-        {512, 512, 128, 128, 768},
-        {1024, 1024, 64, 64, 768},
-        {1024, 1024, 128, 128, 768},
-        {2048, 2048, 64, 64, 768},
-        {2048, 2048, 128, 128, 768},
-        {4096, 4096, 64, 64, 768},
-        {4096, 4096, 128, 128, 768},
+        {256, 256, 64, 64, 48, 16},
+        {256, 256, 128, 128, 48, 16},
+        {512, 512, 64, 64, 48, 16},
+        {512, 512, 128, 128, 48, 16},
+        {1024, 1024, 64, 64, 48, 16},
+        {1024, 1024, 128, 128, 48, 16},
+        {2048, 2048, 64, 64, 48, 16},
+        {2048, 2048, 128, 128, 48, 16},
+        {4096, 4096, 64, 64, 48, 16},
+        {4096, 4096, 128, 128, 48, 16},
     };
     this->bench_  = true;
     this->verify_ = false;
@@ -108,7 +112,7 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
 using ck::tensor_operation::device::GemmSpecialization;

 // TODO: enable KPadding tests when it is implemented
-TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
+TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
 {
     int P = 120; // requires padding
     int Q = 128; // does not require padding
@@ -134,7 +138,7 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
     // clang-format on
 }

-TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
+TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
 {
     // IsSupported(M, N, K, O)
     // clang-format off
@@ -148,13 +152,13 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
     // clang-format on
 }
-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {49, 49, 64, 64, 24},
-        {64, 49, 64, 64, 24},
-        {1020, 1020, 64, 128, 24},
-        {576, 576, 64, 64, 24},
+        {49, 49, 64, 64, 4, 6},
+        {64, 49, 64, 64, 4, 6},
+        {1020, 1020, 64, 128, 4, 6},
+        {576, 576, 64, 64, 4, 6},
     };
     this->bench_ = true;
     this->Run();
......
@@ -18,7 +18,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

 template <typename Tuple>
-struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
+struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
 {
     using ADataType  = std::tuple_element_t<0, Tuple>;
     using B0DataType = std::tuple_element_t<1, Tuple>;
@@ -179,14 +179,12 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
     0,              // StrideA
     0,              // StrideB0
     0,              // StrideB1
-    0,              // StrideC
     0,              // BatchStrideA
     0,              // BatchStrideB0
     0,              // BatchStrideB1
-    0,              // BatchStrideC
     PassThrough{},  // a_element_op
     PassThrough{},  // b0_element_op
-    Scale{},        // acc0_element_op
+    Scale{1.f},     // acc0_element_op
     PassThrough{},  // b1_element_op
     PassThrough{}); // c_element_op