Commit 8cc53111 authored by aska-0096's avatar aska-0096
Browse files

tempsave, failed to generate ideal code based on iglp

parent 29087570
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp"
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
...@@ -18,7 +17,7 @@ using F16 = ck::half_t; ...@@ -18,7 +17,7 @@ using F16 = ck::half_t;
using F32 = float; using F32 = float;
using ALayout = Row; using ALayout = Row;
using BLayout = Row; using BLayout = Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -213,12 +212,11 @@ using DeviceGemmFactory = ...@@ -213,12 +212,11 @@ using DeviceGemmFactory =
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 4, 2, 0, 1, 4, 2, 0,
1, 1, S<1, 32, 1, 8>, 8>, 1, 1, S<1, 32, 1, 8>, 8>,
#endif
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
Row, Row, Row, Row, Row, Row,
F16, F16, F16, F32, F16, F16, F16, F16, F32, F16,
PassThrough, PassThrough, PassThrough, GemmDefault, PassThrough, PassThrough, PassThrough, GemmDefault,
2, 256, 1, 256,
256, 256, 256, 256,
32, 8, 4, 32, 8, 4,
32, 32, 32, 32,
...@@ -228,24 +226,25 @@ using DeviceGemmFactory = ...@@ -228,24 +226,25 @@ using DeviceGemmFactory =
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 4, 4, 1, 1, 4, 4, 1,
1, 1, S<1, 32, 1, 8>, 8, 1, 1, S<1, 32, 1, 8>, 8,
ck::LoopScheduler::Default, ck::PipelineVersion::v1>, ck::LoopScheduler::Default, ck::PipelineVersion::v1>
#endif
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
Row, Row, Row, Row, Col, Row,
F16, F16, F16, F32, F16, F16, F16, F16, F32, F16,
PassThrough, PassThrough, PassThrough, GemmDefault, PassThrough, PassThrough, PassThrough, GemmDefault,
1, 256, 2, 256,
256, 256, 256, 256,
32, 8, 4, 32, 8, 8,
32, 32, 32, 32,
4, 4, 4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1, 2, 8, 8, 0,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
1, 4, 4, 1, 2, 8, 8, 0,
1, 1, S<1, 32, 1, 8>, 8, 1, 1, S<1, 32, 1, 8>, 8,
ck::LoopScheduler::Default, ck::PipelineVersion::v1> ck::LoopScheduler::Default, ck::PipelineVersion::v1>
>;
>;
// clang-format on // clang-format on
......
...@@ -7,12 +7,14 @@ ...@@ -7,12 +7,14 @@
using ADataType = ck::f8_t; using ADataType = ck::f8_t;
using BDataType = ck::f8_t; using BDataType = ck::f8_t;
using CDataType = ck::half_t; using CDataType = ck::f8_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
using F8 = ck::f8_t;
using F32 = float;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -21,15 +23,138 @@ using CElementOp = PassThrough; ...@@ -21,15 +23,138 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off using DeviceGemmFactory = std::tuple<
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle #if 1
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<Row,
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Row,
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Row,
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | F8,
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; F8,
// clang-format on F8,
F32,
F8,
PassThrough,
PassThrough,
PassThrough,
GemmDefault,
1,
256,
256,
128,
64,
16,
4,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
1,
S<8, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
4,
0,
1,
1,
S<1, 64, 1, 4>,
16,
ck::LoopScheduler::Interwave,
ck::PipelineVersion::v1>,
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<Row,
Row,
Row,
F8,
F8,
F8,
F32,
F8,
PassThrough,
PassThrough,
PassThrough,
GemmDefault,
1,
256,
256,
128,
64,
16,
16,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
1,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
16,
1,
1,
1,
S<1, 64, 1, 4>,
16,
ck::LoopScheduler::Interwave,
ck::PipelineVersion::v1>,
#endif
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<Row,
Row,
Row,
F8,
F8,
F8,
F32,
F8,
PassThrough,
PassThrough,
PassThrough,
GemmDefault,
1,
256,
256,
128,
64,
16,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
1,
S<8, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
8,
1,
1,
1,
S<1, 64, 1, 4>,
16,
ck::LoopScheduler::Interwave,
ck::PipelineVersion::v1>>;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
#pragma once #pragma once
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
template <typename ProblemType> template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{ {
...@@ -68,6 +66,26 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -68,6 +66,26 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break; break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 3:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistribution<ADataType>{0.0f, 0.1f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-0.01f, 0.01f}(b_k_n);
break;
case 5:
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 6:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
break;
default: default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n); ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
...@@ -99,62 +117,24 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -99,62 +117,24 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif #endif
DeviceMem workspace;
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK<ALayout, float best_perf = .0;
BLayout, float best_time = .0;
CLayout, std::string best_kernel = "";
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
if constexpr(std::is_same<ProblemType, ProblemSize>::value &&
!std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
{
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument)) ck::static_for<0, std::tuple_size_v<DeviceGemmFactory>, 1>{}([&](auto i) -> void {
{ const auto device_gemm_instance = std::get<i>(DeviceGemmFactory{});
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true; using DeviceGemmInstance = ck::remove_cvref_t<decltype(device_gemm_instance)>;
} // do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
}
else if constexpr(std::is_same<ProblemType, ProblemSizeStreamK>::value &&
std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
{
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()), static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
...@@ -173,51 +153,47 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -173,51 +153,47 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op);
problem_size.NumSKBlocks); #if 0
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true; return true;
} }
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
if(workspace_size != 0)
{
workspace.Realloc(workspace_size);
gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
}
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
#if 0
// TODO!!!!!
if(workspace_size != 0){
float * ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
size_t ws_dwords = workspace_size / sizeof(float);
workspace.FromDevice(ws_ptr);
for(size_t i = 0; i < ws_dwords; i++) {
uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
}
free(ws_ptr);
}
#endif #endif
} ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float gb_per_sec = num_btype / 1.E6 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time;
best_kernel = gemm.GetTypeString();
}
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " float gb_per_sec = num_btype / 1.E6 / ave_time;
<< gemm.GetTypeString() << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: M: " << M << ", N: " << N << ", K: " << K << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " ms" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
if(config.do_verification) if(config.do_verification)
{ {
......
...@@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example() ...@@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example()
{0, 1}, {0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()}, save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()}, save_mean.mDesc.GetStrides().end()},
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
...@@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) ...@@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
{0, 0, 0, C, 1}, {0, 0, 0, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()}, save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()}, save_mean.mDesc.GetStrides().end()},
{1, 2, 4}, // reduction dimension: [H, W, C] {1, 2, 4}, // reduction dimension: [H, W, C]
1e-6, 1e-6,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
...@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example() ...@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example()
{0, W * C, C, 1}, {0, W * C, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()}, save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()}, save_mean.mDesc.GetStrides().end()},
{1, 2, 3}, {1, 2, 3},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#ifdef CK_USE_LAUNCH_BOUNDS #ifdef CK_USE_LAUNCH_BOUNDS
// for most kernels // for most kernels
#define CK_MAX_THREAD_PER_BLOCK 256 #define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 2 #define CK_MIN_BLOCK_PER_CU 1
// for wavelet GEMM kernel // for wavelet GEMM kernel
#define CK_WAVELET_MAX_THREAD_PER_BLOCK 512 #define CK_WAVELET_MAX_THREAD_PER_BLOCK 512
......
...@@ -51,7 +51,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -51,7 +51,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipDeviceSynchronize()); hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_)); hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < profile_repeat; ++i) for(int i = 0; i < nrepeat; ++i)
{ {
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError()); hip_check_error(hipGetLastError());
...@@ -64,7 +64,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -64,7 +64,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
return total_time / profile_repeat; return total_time / nrepeat;
} }
else else
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
#if 0
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
#endif
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr auto xdlops_gemm_sp =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack / 2, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
BlockwiseGemmXdlops_pipeline_v1(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1th
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// Write to LDS immediately to save VGPR
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
__builtin_amdgcn_sched_barrier(0);
// Global prefetch 2th, Hold in VGPR
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Wait all wave produce this K-loop data
block_sync_lds();
// preload data into VGPR, 1 kpack of K-loop
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
// main body
if constexpr(HasMainLoop)
{
#if 0
if(get_thread_local_1d_id() == 0)
{
printf("HasMainLoop=True\n");
}
#endif
index_t i = 0;
do
{
// Here only KRepeat-1 times read (1~KRepeat) & compute (0~KRepat-1) of this k-loop
static_for<1, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
/* Read N+1 */
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, Number<k % 2>{}, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, Number<k % 2>{}, I0, I0),
b_thread_buf);
});
/* Compute N */
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (k - 1) % 2, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (k - 1) % 2, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
/*
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
*/
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
// Wait all wave consume this k-loop data
block_sync_lds();
// Write 2th prefetch k-loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
/* We have entire loop to cover buffer_load latency */
// Next Global prefetch
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
++i;
// compute(idx=KRepeat) this K-loop can hide ds_write latency
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
/* Compute N */
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack / 2, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (KRepeat - 1) % 2, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm_sp.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
// Wait all wave produce next k-loop data
block_sync_lds();
// Here 1 time prefetch read(idx=0) of next K-loop
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
/* Compute N */
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack / 2, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm_sp.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// Current best = 112T at float initialization 3840x4096x4096
// __builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
} while(i < (num_loop - 2));
}
// tail
{
// N-2 th k-loop
// Here only KRepeat-1 times read & compute
static_for<1, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
/* Read N+1 */
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, Number<k % 2>{}, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, Number<k % 2>{}, I0, I0),
b_thread_buf);
});
/* Compute N */
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (k - 1) % 2, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (k - 1) % 2, 0, ik))>{}];
});
#if 0
if(get_thread_local_1d_id() == 0)
{
printf("rep of m.n.k (%01d, %01d, %01d)\n",
m0.value,
n0.value,
k.value - 1);
}
printf(
"Tid: %03d, A_compute_buf: %04x %04x %04x %04x %04x %04x %04x %04x\n",
get_thread_local_1d_id(),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<0>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<1>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<2>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<3>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<4>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<5>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<6>{})))),
*(reinterpret_cast<uint16_t*>(
&(a_thread_vec.template AsType<FloatAB>()(Number<7>{})))));
#endif
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_group_barrier(0x100, 8, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 32, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// compute(idx=KRepeat) this K-loop can hide ds_write latency
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
/* Compute N */
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack / 2, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (KRepeat - 1) % 2, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm_sp.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
__builtin_amdgcn_sched_group_barrier(0x200, 8, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 32, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
// wait final data producer finished
block_sync_lds();
// N-1 th k-loop
// Here 1 time prefetch read(idx=0) of next K-loop
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
/* Compute N */
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack / 2, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (KRepeat - 1) % 2, 0, ik + KPack / 2))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm_sp.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
__builtin_amdgcn_sched_group_barrier(0x100, 8, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 32, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
// Here only KRepeat-1 times read & compute
static_for<1, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
/* Read N+1 */
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, Number<k % 2>{}, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, Number<k % 2>{}, I0, I0),
b_thread_buf);
});
/* Compute N */
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (k - 1) % 2, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (k - 1) % 2, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_group_barrier(0x100, 8, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 32, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
/* Final Compute issue */
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (KRepeat - 1) % 2, 0, i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, I2, I1, Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I2, I1, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
#if 0
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
#endif
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
BlockwiseGemmXdlops_pipeline_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1th
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// Write to LDS immediately to save VGPR
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
__builtin_amdgcn_sched_barrier(0);
// Global prefetch 2th, Hold in VGPR
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// main body
if constexpr(HasMainLoop)
{
#if 0
if(get_thread_local_1d_id() == 0)
{
printf("HasMainLoop=True\n");
}
#endif
index_t i = 0;
do
{
// Wait all wave produce this K-loop data
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
/* Read N+1 */
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, Number<k % 2>{}, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, Number<k % 2>{}, I0, I0),
b_thread_buf);
});
});
});
/* Compute N */
static_for<0, MRepeat / 2, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 16 ds_read_b128 and 16 v_mfma
__builtin_amdgcn_sched_group_barrier(0x100, 8, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// Write 2th prefetch k-loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
/* We have entire loop to cover buffer_load latency */
// Next Global prefetch
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<MRepeat / 2, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I1, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I1, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 8 ds_write_b128, 8 buffer_load_dwordx4, 48 v_mfma
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
++i;
} while(i < (num_loop - 2));
}
// tail
{
// Wait all wave produce this K-loop data
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
/* Read N+1 */
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, Number<k % 2>{}, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, Number<k % 2>{}, I0, I0),
b_thread_buf);
});
});
});
/* Compute N */
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, 0, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, 0, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 16 ds_read_b128 and 32 v_mfma
__builtin_amdgcn_sched_group_barrier(0x100, 8, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// Write 2th prefetch k-loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I1, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I1, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 8 ds_write_b128, 8 buffer_load_dwordx4, 32 v_mfma
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
// Wait all wave produce this K-loop data
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) { // k=1,2 instead of kpack*1, ...
/* Read N+1 */
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, Number<k % 2>{}, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, Number<k % 2>{}, I0, I0),
b_thread_buf);
});
});
/* Compute N */
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, k % 2, 0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, k % 2, 0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// schedule 16 ds_read_b128 and 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x100, 8, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 5, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, I2, I1, Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I2, I1, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
#if 0
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
#endif
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
BlockwiseGemmXdlops_pipeline_v3(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
// Inst List:
// ds_read_b128: 16
// ds_write_b128: 8
// buffer_load_dwordx4: 16
// v_mfma: 0
// -------------------------------------------------------------------------------------------
// Global prefetch 1th
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2th
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
// Inst List:
// ds_read_b128: 32
// ds_write_b128: 16
// buffer_load_dwordx4: 16
// v_mfma: 128
// -------------------------------------------------------------------------------------------
// I0 As Compute, I1 As Read
{
using ComputeLocalBuffer = Number<0>;
using MemoryLocalBuffer = Number<1>;
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, ComputeLocalBuffer{}, I0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, ComputeLocalBuffer{}, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
static_for<0, MRepeat / 2, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, ComputeLocalBuffer{}, I1, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, ComputeLocalBuffer{}, I1, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 8 ds_write_b128, 8 buffer_load_dwordx4, 48 v_mfma
// two single MFMA for hide ds_write & VMEM read issue latency
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, MemoryLocalBuffer{}, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, MemoryLocalBuffer{}, k, I0),
b_thread_buf);
});
});
});
static_for<MRepeat / 2, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, ComputeLocalBuffer{}, I1, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, ComputeLocalBuffer{}, I1, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 16 ds_read_b128 and 16 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_barrier(0);
}
// -------------------------------------------------------------------------------------------
// I1 As Compute, I0 As Read
{
using ComputeLocalBuffer = Number<1>;
using MemoryLocalBuffer = Number<0>;
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, ComputeLocalBuffer{}, I0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, ComputeLocalBuffer{}, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
static_for<0, MRepeat / 2, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, ComputeLocalBuffer{}, I1, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, ComputeLocalBuffer{}, I1, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 8 ds_write_b128, 8 buffer_load_dwordx4, 48 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, MemoryLocalBuffer{}, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, MemoryLocalBuffer{}, k, I0),
b_thread_buf);
});
});
});
static_for<MRepeat / 2, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, ComputeLocalBuffer{}, I1, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, ComputeLocalBuffer{}, I1, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 16 ds_read_b128 and 16 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_barrier(0);
}
i += 2;
} while(i < (num_loop - 2));
}
// tail
{
// Inst List:
// ds_read_b128: 16
// ds_write_b128: 8
// buffer_load_dwordx4: 0
// v_mfma: 128
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 8 ds_write_b128 and 32 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I1, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I1, k, I0),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I1, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I1, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// schedule 16 ds_read_b128 and 32 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I1, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I1, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I2, Number<KRepeat>{}, Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I2, Number<KRepeat>{}, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
// Double LDS buffer
// Prefetech 2 stage
// Local prefetch 1 stage
namespace ck {
#if 0
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
#endif
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v4
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__device__ static constexpr auto HotLoopScheduler()
{
// schedule
constexpr auto num_ds_read_inst = 16;
constexpr auto num_ds_write_inst = 8;
constexpr auto num_buffer_load_inst = 8;
constexpr auto num_mfma_inst = 64;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst/num_buffer_load_inst, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, num_ds_write_inst/num_buffer_load_inst, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst/num_buffer_load_inst-3, 0); // MFMA
});
}
template <index_t stage>
__device__ static constexpr auto TailScheduler()
{
}
template <>
__device__ static constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst = 16;
constexpr auto num_ds_write_inst = 8;
constexpr auto num_mfma_inst = 64;
constexpr auto num_issue = num_ds_write_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst/num_ds_write_inst-1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst/num_ds_write_inst-3, 0); // MFMA
});
}
template <>
__device__ static constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst = 16;
constexpr auto num_mfma_inst = 64;
constexpr auto num_issue = num_ds_read_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst/num_ds_read_inst, 0); // MFMA
});
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
// Inst List:
// ds_read_b128: 16
// ds_write_b128: 8
// buffer_load_dwordx4: 16
// v_mfma: 0
// -------------------------------------------------------------------------------------------
// Global prefetch 1th, Fill Ping LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Local prefetch 1th, Fill Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
});
});
// Global prefetch 2th, Fill Pong LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
// Global prefetch 3rd
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
// -------------------------------------------------------------------------------------------
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 3));
}
// tail
{
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
// Double LDS buffer
// Prefetech 2 stage
// Local prefetch 1 stage
namespace ck {
#if 0
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
#endif
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v4
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__
BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__device__ static constexpr auto HotLoopScheduler()
{
// schedule
constexpr auto num_ds_read_inst = 16;
constexpr auto num_ds_write_inst = 8;
constexpr auto num_buffer_load_inst = 8;
constexpr auto num_mfma_inst = 64;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst/num_buffer_load_inst, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, num_ds_write_inst/num_buffer_load_inst, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst/num_buffer_load_inst-3, 0); // MFMA
});
}
template <index_t stage>
__device__ static constexpr auto TailScheduler()
{
}
template <>
__device__ static constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst = 16;
constexpr auto num_ds_write_inst = 8;
constexpr auto num_mfma_inst = 64;
constexpr auto num_issue = num_ds_write_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst/num_ds_write_inst-1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst/num_ds_write_inst-3, 0); // MFMA
});
}
template <>
__device__ static constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst = 16;
constexpr auto num_mfma_inst = 64;
constexpr auto num_issue = num_ds_read_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_inst/num_ds_read_inst, 0); // MFMA
});
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
// -------------------------------------------------------------------------------------------
// Global prefetch 1th, Fill Ping LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Local prefetch 1th, Fill Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
});
// Global prefetch 2th, Fill Pong LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
// Global prefetch 3rd
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
// -------------------------------------------------------------------------------------------
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, PongP1{}, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, PongP1{}, k, I0),
b_thread_buf);
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, PingP1{}, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, PingP1{}, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, PongP2{}, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, PongP2{}, k, I0),
b_thread_buf);
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, PingP2{}, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, PingP2{}, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 3));
}
// tail
{
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, PongP1{}, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, PongP1{}, k, I0),
b_thread_buf);
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, PingP1{}, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, PingP1{}, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, PingP2{}, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, PingP2{}, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, PongP2{}, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, PongP2{}, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I2, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I2, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
...@@ -37,7 +37,8 @@ template <index_t BlockSize, ...@@ -37,7 +37,8 @@ template <index_t BlockSize,
index_t NPerXDL, index_t NPerXDL,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -59,7 +60,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -59,7 +60,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{}; static constexpr auto xdlops_gemm =
XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...@@ -185,6 +187,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -185,6 +187,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N)); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
} }
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{ {
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
...@@ -211,6 +227,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -211,6 +227,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
} }
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{ {
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
...@@ -300,33 +330,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -300,33 +330,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KPerThread / KPack, 1>{}([&](auto k) {
// read A static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, // read A
make_tuple(m0, I0, I0, I0), a_thread_copy_.Run(
a_block_buf, a_block_desc_m0_m1_m2_k,
a_thread_desc_, make_tuple(m0, I0, I0, Number<(k.value * KPack)>{}),
make_tuple(I0, I0, I0, I0), a_block_buf,
a_thread_buf); a_thread_desc_,
make_tuple(Number<(k.value % 2)>{}, Number<(m0.value % 2)>{}, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B // read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0), make_tuple(n0, I0, I0, Number<(k.value * KPack)>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(Number<(k.value % 2)>{}, n0, I0, I0),
b_thread_buf); b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatA, KPack> a_thread_vec; vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec; vector_type<FloatB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf a_thread_vec.template AsType<FloatA>()(i) =
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(k % 2, m0 % 2, 0, i))>{}];
b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; [Number<b_thread_desc_.CalculateOffset(make_tuple(k % 2, n0, 0, i))>{}];
}); });
using mfma_input_type_a = using mfma_input_type_a =
...@@ -349,11 +381,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -349,11 +381,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
protected: protected:
// A[M0, M1, M2, KPerThread] // A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{})); make_naive_tensor_descriptor_packed(make_tuple(I2, I2, I1, Number<KPack>{}));
// B[N0, N1, N2, KPerThread] // B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{})); make_naive_tensor_descriptor_packed(make_tuple(I2, Number<NRepeat>{}, I1, Number<KPack>{}));
// C[M, N, NumRegXdlops] // C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
...@@ -363,7 +395,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -363,7 +395,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
FloatA, FloatA,
decltype(a_block_desc_m0_m1_m2_k), decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>, Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
A_K1, A_K1,
...@@ -373,7 +405,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -373,7 +405,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
FloatB, FloatB,
decltype(b_block_desc_n0_n1_n2_k), decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>, Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
3, 3,
B_K1, B_K1,
...@@ -398,6 +430,7 @@ template <index_t BlockSize, ...@@ -398,6 +430,7 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
bool TransposeC = false,
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS> index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
...@@ -410,7 +443,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -410,7 +443,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack,
TransposeC>
{ {
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatA, FloatA,
...@@ -422,7 +456,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -422,7 +456,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>; KPack,
TransposeC>;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using Base::a_block_desc_m0_m1_m2_k; using Base::a_block_desc_m0_m1_m2_k;
...@@ -586,7 +621,8 @@ template <index_t BlockSize, ...@@ -586,7 +621,8 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
LoopScheduler LoopSched> LoopScheduler LoopSched,
bool TransposeC = false>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
...@@ -601,7 +637,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -601,7 +637,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
TransposeC>{};
} }
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
...@@ -615,7 +652,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -615,7 +652,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
TransposeC>{};
} }
}; };
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1_doublelds.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -155,10 +155,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -155,10 +155,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
float ave_time = 0; float ave_time = 0;
#if 1
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
#endif
#if 0
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>; const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>;
...@@ -173,7 +178,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -173,7 +178,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
} }
#endif
return ave_time; return ave_time;
} }
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...@@ -22,7 +22,8 @@ __global__ void ...@@ -22,7 +22,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -357,6 +358,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -357,6 +358,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
} }
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static auto __host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{ {
...@@ -810,25 +831,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -810,25 +831,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
constexpr index_t KPack = math::max( constexpr index_t KPack =
math::lcm(AK1Number, BK1Number), math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk); MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,
// ComputeType,
// FloatGemmAcc,
// decltype(a_block_desc_ak0_m_ak1),
// decltype(b_block_desc_bk0_n_bk1),
// MPerXdl,
// NPerXdl,
// MXdlPerWave,
// NXdlPerWave,
// KPack,
// LoopSched>();
auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v3<
BlockSize, BlockSize,
ComputeTypeA, ComputeTypeA,
ComputeTypeB,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
KPack, KPack>{}; // TransposeC
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
...@@ -846,27 +882,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -846,27 +882,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// gridwise GEMM pipeline // gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>); static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock); KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1, blockwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
a_blockwise_copy, a_blockwise_copy,
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, c_thread_buf,
c_thread_buf, num_k_block_main_loop);
num_k_block_main_loop);
// shuffle C and write out // shuffle C and write out
{ {
...@@ -879,12 +914,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -879,12 +914,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// TODO: hacky, fix it! // TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it! // TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
...@@ -924,7 +959,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -924,7 +959,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = problem;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ALayout,
typename BLayout,
typename CLayout,
typename FloatA,
typename FloatB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
__host__ static auto CalculateNPadded(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__host__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
return CalculateKPadded(K) / AK1Value;
}
else
{
return K / AK1Value;
}
}
__host__ static auto CalculateBK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
return CalculateKPadded(K) / BK1Value;
}
else
{
return K / BK1Value;
}
}
__host__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_floor(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_floor(N, NPerBlock);
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KPadded{CalculateKPadded(K_)},
AK0{CalculateAK0(K_)},
BK0{CalculateBK0(K_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
index_t MBlock;
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const FloatA* p_a_grid_,
const FloatB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatA* p_a_grid;
const FloatB* p_b_grid;
FloatC* p_c_grid;
};
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max(2 * (a_block_space_size_aligned * sizeof(ComputeTypeA) +
b_block_space_size_aligned * sizeof(ComputeTypeB)),
c_block_size * sizeof(FloatCShuffle));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.M % MPerBlock == 0))
{
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.N % NPerBlock == 0))
{
return false;
}
}
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
!(CalculateKPadded(problem.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
if(num_k_loop < 3 || (num_k_loop - 3) % 2 != 0)
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return num_loop > 3;
}
template <typename CGridDesc>
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const Problem& problem)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,
// ComputeType,
// FloatGemmAcc,
// decltype(a_block_desc_ak0_m_ak1),
// decltype(b_block_desc_bk0_n_bk1),
// MPerXdl,
// NPerXdl,
// MXdlPerWave,
// NXdlPerWave,
// KPack,
// LoopSched>();
auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
BlockSize,
ComputeTypeA,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>{}; // TransposeC
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeA*>(p_shared) + a_block_space_size_aligned +
b_block_space_size_aligned,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned +
a_block_space_size_aligned + b_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
// const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = problem;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ALayout,
typename BLayout,
typename CLayout,
typename FloatA,
typename FloatB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
__host__ static auto CalculateNPadded(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__host__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
return CalculateKPadded(K) / AK1Value;
}
else
{
return K / AK1Value;
}
}
__host__ static auto CalculateBK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
return CalculateKPadded(K) / BK1Value;
}
else
{
return K / BK1Value;
}
}
__host__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_floor(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_floor(N, NPerBlock);
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KPadded{CalculateKPadded(K_)},
AK0{CalculateAK0(K_)},
BK0{CalculateBK0(K_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
index_t MBlock;
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const FloatA* p_a_grid_,
const FloatB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatA* p_a_grid;
const FloatB* p_b_grid;
FloatC* p_c_grid;
};
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
b_block_space_size_aligned * sizeof(ComputeTypeB)),
c_block_size * sizeof(FloatCShuffle));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Problem& problem)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.M % MPerBlock == 0))
{
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.N % NPerBlock == 0))
{
return false;
}
}
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
!(CalculateKPadded(problem.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
template <typename CGridDesc>
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const Problem& problem)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeTypeA,
ComputeTypeB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched,
true>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2 // M2 = MPerXdl
)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2,
N3,
N4))), // N2 * N3 * N4 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
N4,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
};
} // namespace ck
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' # git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
...@@ -14,7 +14,6 @@ TYPED_TEST(TestTranspose, Test1) ...@@ -14,7 +14,6 @@ TYPED_TEST(TestTranspose, Test1)
this->Run(); this->Run();
} }
TYPED_TEST(TestTranpose, Test2) TYPED_TEST(TestTranpose, Test2)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
...@@ -27,4 +26,3 @@ TYPED_TEST(TestTranpose, Test2) ...@@ -27,4 +26,3 @@ TYPED_TEST(TestTranpose, Test2)
this->Run(); this->Run();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment