Commit 6e0a93d2 authored by wangshaojie6's avatar wangshaojie6
Browse files

add test

parent 8cdcad67
...@@ -38,7 +38,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -38,7 +38,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
int N, int N,
int K, int K,
int O, int O,
int BatchCount = 1, int G0,
int G1,
int StrideA = -1, int StrideA = -1,
int StrideB0 = -1, int StrideB0 = -1,
int StrideB1 = -1, int StrideB1 = -1,
...@@ -46,7 +47,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -46,7 +47,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
int BatchStrideA = -1, int BatchStrideA = -1,
int BatchStrideB0 = -1, int BatchStrideB0 = -1,
int BatchStrideB1 = -1, int BatchStrideB1 = -1,
int BatchStrideC = -1) int BatchStrideC = -1,
float alpha = 1.f)
{ {
...@@ -68,7 +70,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -68,7 +70,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
AccDataType, AccDataType,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
CElementOp>; Acc0ElementOp>;
// Ref Softmax: fp32 in, various type out // Ref Softmax: fp32 in, various type out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
...@@ -85,6 +87,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -85,6 +87,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
bool pass = true; bool pass = true;
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K; const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N; const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
...@@ -105,6 +110,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -105,6 +110,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
const int BatchCount = G0 * G1;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
...@@ -130,18 +137,22 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -130,18 +137,22 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o( Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result( Tensor<CDataType> c_gs_ms_os_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
Tensor<CDataType> c_g_m_o_device_result( std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
// Host verification: Output of Gemm0 is input A of Gemm1 // Host verification: Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
std::vector<int>{M * O, O, 1});
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
std::srand(1); // work around test flakiness std::srand(1); // work around test flakiness
switch(init_method) switch(init_method)
...@@ -178,7 +189,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -178,7 +189,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize()); DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
...@@ -220,7 +232,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -220,7 +232,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
});
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_invoker = ref_softmax.MakeInvoker();
...@@ -234,6 +248,16 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -234,6 +248,16 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument); ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
} }
std::string best_op_name; std::string best_op_name;
...@@ -302,7 +326,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -302,7 +326,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
pass = pass & pass = pass &
ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData); ck::utils::check_err(c_g_m_o_device_result.mData, c_gs_ms_os_host_result.mData);
if(do_log) if(do_log)
{ {
...@@ -313,7 +337,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -313,7 +337,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",") LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",") std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",") std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",")
......
add_custom_target(test_batched_gemm_softmax_gemm)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_softmax_gemm_util.hpp"
template <typename Tuple>
class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16) { this->Run(); }
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 1},
{128, 128, 136, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddM)
{
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 1},
{128, 128, 129, 128, 1},
};
this->Run();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
{
this->lengths_ = std::vector<std::vector<int>>{
{256, 256, 64, 64, 768},
{256, 256, 128, 128, 768},
{512, 512, 64, 64, 768},
{512, 512, 128, 128, 768},
{1024, 1024, 64, 64, 768},
{1024, 1024, 128, 128, 768},
{2048, 2048, 64, 64, 768},
{2048, 2048, 128, 128, 768},
{4096, 4096, 64, 64, 768},
{4096, 4096, 128, 128, 768},
};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
// TODO: enable KPadding tests when it is implemented
TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
{
int P = 120; // requires padding
int Q = 128; // do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
{
this->lengths_ = std::vector<std::vector<int>>{
{49, 49, 64, 64, 24},
{64, 49, 64, 64, 24},
{1020, 1020, 64, 128, 24},
{576, 576, 64, 64, 24},
};
this->bench_ = true;
this->Run();
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
{
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using B1DataType = std::tuple_element_t<2, Tuple>;
using CDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using B0Layout = std::tuple_element_t<5, Tuple>;
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
};
bool bench_ = false;
bool verify_ = true;
void RunSingle(int M, int N, int K, int O, int BatchCount)
{
bool pass = ck::profiler::profile_batched_gemm_softmax_gemm_impl<ADataType,
B0DataType,
B1DataType,
CDataType,
ALayout,
B0Layout,
B1Layout,
CLayout>(
verify_, 1, false, bench_, M, N, K, O, BatchCount);
EXPECT_TRUE(pass);
}
void Run()
{
for(auto lengths : this->lengths_)
{
int M = lengths[0];
int N = lengths[1];
int K = lengths[2];
int O = lengths[3];
int BatchCount = lengths[4];
this->RunSingle(M, N, K, O, BatchCount);
}
}
};
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = F16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using DeviceGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
bool IsSupported(int M, int N, int K, int O)
{
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
M,
N,
K,
O,
0, // BatchCount
0, // StrideA
0, // StrideB0
0, // StrideB1
0, // StrideC
0, // BatchStrideA
0, // BatchStrideB0
0, // BatchStrideB1
0, // BatchStrideC
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
PassThrough{}, // acc0_element_op
PassThrough{}, // b1_element_op
PassThrough{}); // c_element_op
return gemm.IsSupportedArgument(argument);
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment