Commit c7c7f754 authored by aska-0096's avatar aska-0096
Browse files

RCR GEMM Instances for AIT

parent c1d9c0ae
......@@ -22,12 +22,86 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
// Single WMMA CShuffle device GEMM configuration.
// Every argument group below is labelled with the template-parameter name it fills.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
    // ALayout, BLayout, CLayout
    ALayout, BLayout, CLayout,
    // ADataType, BDataType, CDataType, AccDataType, CShuffleDataType
    ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
    // A / B / C elementwise operations
    AElementOp, BElementOp, CElementOp,
    // GEMM specialization
    GemmDefault,
    // BlockSize
    256,
    // MPerBlock, NPerBlock, K0PerBlock, K1
    128, 256, 8, 8,
    // MPerWMMA, NPerWMMA, MRepeat, NRepeat
    16, 16, 4, 4,
    // ABlockTransfer: ThreadClusterLengths_K0_M_K1, ThreadClusterArrangeOrder, SrcAccessOrder,
    //                 SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, ABlockLdsAddExtraM
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
    // BBlockTransfer: ThreadClusterLengths_K0_N_K1, ThreadClusterArrangeOrder, SrcAccessOrder,
    //                 SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, BBlockLdsAddExtraN
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
    // CShuffle MWmmaPerWavePerShuffle, NWmmaPerWavePerShuffle
    1, 1,
    // CBlockTransferClusterLengths_MBlock_MWaveMPerWmma_NBlock_NWaveNPerWmma
    S<1, 32, 1, 8>,
    // CBlockTransferScalarPerVector_NWaveNPerWmma
    8,
    // NOTE(review): this trailing argument had no column in the original parameter table;
    // confirm its meaning against the DeviceGemmWmma_CShuffle declaration.
    1>;
using DeviceGemmInstances = std::tuple<
// RCR Gemm AIT
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
128, 256, 8, 8, 16, 16, 4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
256, 128, 8, 8, 16, 16, 8, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
128, 256, 4, 8, 16, 16, 4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
256, 128, 4, 8, 16, 16, 8, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
128, 128, 8, 8, 16, 16, 4, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
128, 128, 4, 8, 16, 16, 4, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
256, 64, 8, 8, 16, 16, 8, 1,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 256,
64, 256, 8, 8, 16, 16, 2, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 128,
128, 128, 8, 8, 16, 16, 8, 2,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 16, 1, 8>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 128,
128, 64, 8, 8, 16, 16, 4, 2,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 32, 1, 4>, 8, 1>,
ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, 128,
64, 128, 8, 8, 16, 16, 4, 2,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
1, 1, S<1, 16, 1, 8>, 8, 1>
>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -11,7 +11,15 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
// auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -71,6 +79,9 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
ck::static_for<0, std::tuple_size_v<DeviceGemmInstances>, 1>{}([&](auto i) {
const auto device_gemm_instance = std::get<i>(DeviceGemmInstances{});
using DeviceGemmInstance = ck::remove_cvref_t<decltype(device_gemm_instance)>;
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
......@@ -96,9 +107,9 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this Gemm problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
......@@ -116,6 +127,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
if(config.do_verification)
{
#if 0
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......@@ -137,8 +149,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
#endif
}
});
return true;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment