"...composable_kernel.git" did not exist on "7494c1c611382069243709cd9eab0efb90843c7c"
Commit cbf14ee1 authored by aska-0096's avatar aska-0096
Browse files

tempsave, epilogue optimization for universal GEMM done. TODO: multiple-D epilogue optimization

parent 1ca98e75
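What this commit does, in one hedged sketch (plain host C++, illustrative only; the names mirror the kernel code but nothing below is the actual device implementation): the blockwise GEMM pipelines stop issuing the last K-chunk's MFMAs inside the hot loop and instead hand the final A/B register tiles back to the caller, so the gridwise epilogue can interleave that leftover math with the C-shuffle LDS writes/reads and global stores.

#include <array>
#include <cstdio>

int main()
{
    constexpr int K = 8;
    std::array<int, K> a{}, b{};
    for(int k = 0; k < K; ++k) { a[k] = k + 1; b[k] = 2 * (k + 1); }

    int c = 0;

    // Old structure: all K multiply-accumulates finish before the epilogue starts.
    // New structure (this commit): run K - 1 steps and stash the last operands aside ...
    for(int k = 0; k < K - 1; ++k) { c += a[k] * b[k]; }
    int a_tail = a[K - 1]; // plays the role of a_thread_buf_tail
    int b_tail = b[K - 1]; // plays the role of b_thread_buf_tail

    // ... so the epilogue can do the remaining work while it writes results out;
    // on the GPU this lets MFMA instructions hide the C-shuffle LDS/VMEM latency.
    c += a_tail * b_tail;
    std::printf("c = %d\n", c);
    return 0;
}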
...@@ -32,6 +32,7 @@ target_compile_options(example_gemm_xdl_fp8_v3 PRIVATE -mllvm -greedy-reverse-lo
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
target_compile_options(example_gemm_xdl_bf16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
......
...@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmV2Instance =
...@@ -28,15 +28,15 @@ using DeviceGemmV2Instance =
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 128,
256, 256,
64, 8, 8,
32, 8, 8,
16, 16,
32, 32,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, 8,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
......
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
add_example_executable(example_gemm_bilinear_xdl_fp16_v3 gemm_bilinear_xdl_fp16_v3.cpp)
target_compile_options(example_gemm_bilinear_xdl_fp16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
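# Note (not in the original commit): -save-temps=$PWD presumably keeps the preprocessed/asm
# temporaries in the working directory so the generated epilogue schedule can be inspected,
# -mllvm -greedy-reverse-local-assignment=1 tweaks LLVM's greedy register allocator, and
# -Wno-gnu-line-marker silences the line-marker warnings that -save-temps otherwise triggers.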
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};
float alpha_;
float beta_;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Single D tensor (the bilinear addend); the multiple-D instance expects tuple types.
using DsDataType = ck::Tuple<DDataType>;
using DsLayout = ck::Tuple<DLayout>;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
ALayout, BLayout, DsLayout, ELayout,
ADataType, BDataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256,
128, 128,
64, 8, 8,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
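// Note (not part of the commit): read against the usual CK v3 template-parameter order, the
// integers above are assumed to mean BlockSize = 256, per-block tile M x N x K = 128 x 128 x 64,
// AK1 = BK1 = 8, MPerXDL = NPerXDL = 16, MXdlPerWave = NXdlPerWave = 4, with the S<...> sequences
// describing the A/B block-transfer thread clusters and the C-shuffle block transfer.
// Treat this mapping as a hypothesis, not a restatement of the device-op header.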
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
float alpha = 1.0f;
float beta = 1.0f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]);
beta = std::stof(argv[5]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
alpha = std::stof(argv[11]);
beta = std::stof(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
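A quick usage sketch for the new example (the invocation is illustrative; only the argument order is taken from the argc == 13 branch above):

// ./example_gemm_bilinear_xdl_fp16_v3 1 2 0 3840 4096 4096 4096 4096 4096 4096 1.5 0.5
// -> verify against the host reference, float initialization, no kernel timing, and compute
//    E = 1.5 * (A * B) + 0.5 * D elementwise through AlphaBetaAdd.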
...@@ -169,7 +169,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename AThreadBuffer,
typename BThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
...@@ -183,6 +185,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AThreadBuffer& a_thread_buf_tail,
BThreadBuffer& b_thread_buf_tail,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
...@@ -298,37 +302,11 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
...@@ -470,7 +448,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename AThreadBuffer,
typename BThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
...@@ -484,6 +464,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AThreadBuffer& a_thread_buf_tail,
BThreadBuffer& b_thread_buf_tail,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
...@@ -630,64 +612,13 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
b_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
if constexpr(k0.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
}
}
protected:
// K->M loopover
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
......
...@@ -208,7 +208,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename AThreadBuffer,
typename BThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
...@@ -222,6 +224,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AThreadBuffer& a_thread_buf_tail,
BThreadBuffer& b_thread_buf_tail,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
...@@ -407,33 +411,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
};
if constexpr(TailNum == TailNumber::One)
...@@ -458,33 +437,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
}
else if constexpr(TailNum == TailNumber::Two)
{
...@@ -516,7 +470,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
...@@ -699,7 +652,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename AThreadBuffer,
typename BThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
...@@ -713,6 +668,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AThreadBuffer& a_thread_buf_tail,
BThreadBuffer& b_thread_buf_tail,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
...@@ -951,60 +908,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
b_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
if constexpr(k0.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
};
if constexpr(TailNum == TailNumber::One)
...@@ -1027,60 +934,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
b_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
if constexpr(k0.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
}
else if constexpr(TailNum == TailNumber::Two)
{
...@@ -1112,7 +969,6 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
}
}
protected:
// K->M loopover
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
......
...@@ -425,40 +425,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
{
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
#if 0
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
#endif
// __builtin_amdgcn_sched_barrier(0);
}
}
// protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
......
...@@ -249,7 +249,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename AThreadBuffer,
typename BThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
...@@ -263,6 +265,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AThreadBuffer& a_thread_buf_tail,
BThreadBuffer& b_thread_buf_tail,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
...@@ -539,33 +543,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
};
auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
a_thread_buf_tail = a_thread_bufs[mfma_reg_buf];
b_thread_buf_tail = b_thread_bufs[mfma_reg_buf];
};
// tail
if constexpr(TailNum == TailNumber::Odd)
...@@ -583,7 +562,6 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
......
...@@ -330,7 +330,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename AThreadBuffer,
typename BThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
...@@ -344,12 +346,14 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AThreadBuffer& a_thread_buf_tail,
BThreadBuffer& b_thread_buf_tail,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
a_thread_desc_loop.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
b_thread_desc_loop.GetElementSpaceSize());
// Global prefetch 1 // Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
...@@ -382,18 +386,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -382,18 +386,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1 // Local prefetch 1
block_sync_lds(); block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_loop.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0), make_tuple(m0, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_loop,
make_tuple(m0, I0, I0, I0), make_tuple(m0, I0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_loop.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0), make_tuple(n0, I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_loop,
make_tuple(n0, I0, I0, I0), make_tuple(n0, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
...@@ -428,12 +432,12 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -428,12 +432,12 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_loop.CalculateOffset(
make_tuple(m0, I0, I0, ik))>{}]; make_tuple(m0, I0, I0, ik))>{}];
}); });
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_loop.CalculateOffset(
make_tuple(n0, I0, I0, ik))>{}]; make_tuple(n0, I0, I0, ik))>{}];
}); });
...@@ -450,21 +454,21 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -450,21 +454,21 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
a_thread_copy_.Run( a_thread_copy_loop.Run(
a_block_desc_m0_m1_m2_k, a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_loop,
make_tuple(m0, I0, I0, I0), make_tuple(m0, I0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_loop.Run(
b_block_desc_n0_n1_n2_k, b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_loop,
make_tuple(n0, I0, I0, I0), make_tuple(n0, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
...@@ -498,12 +502,12 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -498,12 +502,12 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_loop.CalculateOffset(
make_tuple(m0, I0, I0, ik))>{}]; make_tuple(m0, I0, I0, ik))>{}];
}); });
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_loop.CalculateOffset(
make_tuple(n0, I0, I0, ik))>{}]; make_tuple(n0, I0, I0, ik))>{}];
}); });
...@@ -517,21 +521,21 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave, ...@@ -517,21 +521,21 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
a_thread_copy_.Run( a_thread_copy_loop.Run(
a_block_desc_m0_m1_m2_k, a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_loop,
make_tuple(m0, I0, I0, I0), make_tuple(m0, I0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run( b_thread_copy_loop.Run(
b_block_desc_n0_n1_n2_k, b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_loop,
make_tuple(n0, I0, I0, I0), make_tuple(n0, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
...@@ -540,74 +544,25 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
HotLoopScheduler();
};
auto ReadCompFunc = [&]() {
vector_type<ComputeDataType, KPack> a_thread_vec;
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, I0),
make_tuple(m0, I0, k0, I0),
a_thread_buf);
a_thread_buf_tail);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
make_tuple(n0, I0, I0, Number<k0* BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, I0, I0),
make_tuple(n0, I0, k0, I0),
b_thread_buf);
b_thread_buf_tail);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
...@@ -627,19 +582,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
}
}
protected:
// A[MRepeat, I1, I1, KPack]
static constexpr auto a_thread_desc_ =
static constexpr auto a_thread_desc_loop =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, I1, I1, Number<KPack>{}));
// B[NRepeat, N1, N2, KPack]
static constexpr auto b_thread_desc_ =
static constexpr auto b_thread_desc_loop =
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
decltype(a_thread_desc_loop),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
...@@ -649,15 +603,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
decltype(b_thread_desc_loop),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
AThreadCopy a_thread_copy_loop{Base::CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
BThreadCopy b_thread_copy_loop{Base::CalculateBThreadOriginDataIndex()};
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
......
...@@ -198,7 +198,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
(MPerBlock * NPerBlock / BlockSize <= 128) ? 2
: (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave) ? 1
: 2;
if(has_main_k_block_loop)
{
......
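A small worked example of the new minimum_occupancy heuristic above (tile shapes are illustrative, assuming BlockSize = 256):

// 128 x 128 tile: 128 * 128 / 256 = 64  <= 128 -> minimum_occupancy = 2 (small per-thread tile, favor occupancy)
// 256 x 128 tile: 256 * 128 / 256 = 128 <= 128 -> minimum_occupancy = 2
// 256 x 256 tile: 256 * 256 / 256 = 256 >  128 -> 1 for Intrawave, 2 otherwise (big tile, favor registers/LDS)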
...@@ -1176,6 +1176,38 @@ struct GridwiseGemm_xdl_cshuffle_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__device__ static constexpr auto EpilogueScheduler()
{
constexpr auto epilogue_tile = MPerBlock * NPerBlock * CShuffleMXdlPerWavePerShuffle *
CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave);
constexpr auto num_mfma_inst = BlockwiseGemmPipe::HotLoopInstList::C_MFMA_Inst_Num *
CShuffleMXdlPerWavePerShuffle *
CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave);
constexpr auto num_ds_write_inst =
epilogue_tile / BlockSize; // DefaultMFMA, per-element write
constexpr auto num_ds_read_inst =
epilogue_tile / BlockSize / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto num_buffer_store_inst = num_ds_read_inst;
// MFMA:ds_write=1:2
constexpr auto num_ds_write_issue = num_ds_write_inst / 2;
constexpr auto num_mfma_block_sync = (num_mfma_inst - num_ds_write_issue) / 2;
constexpr auto mfma_ds_write_rate = MXdlPerWave==16 ? 2 : 4;
// Hide ds_write issue latency
static_for<0, num_ds_write_issue, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, mfma_ds_write_rate, 0); // DS write
});
// Hide block_sync + ds_read latency
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst, 0); // DS read
// Hide block_sync latency
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x040, num_buffer_store_inst, 0); // VMEM write
}
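// Note (not part of the diff): __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the
// LLVM scheduler to place `size` instructions matching `mask` at this point, so the static_for
// above requests a repeating pattern of one MFMA followed by mfma_ds_write_rate LDS writes, and
// the trailing calls pair the remaining MFMAs first with the DS reads and then with the buffer
// (VMEM) stores of the C-shuffle. Mask bits used here: 0x008 = MFMA/WMMA, 0x040 = VMEM write,
// 0x100 = DS read, 0x200 = DS write.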
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
...@@ -1323,6 +1355,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_;
constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_;
constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
b_thread_desc.GetElementSpaceSize());
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...@@ -1341,14 +1381,21 @@ struct GridwiseGemm_xdl_cshuffle_v3
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_thread_buf,
b_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
// Last block MFMA
auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm;
constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
// Shuffle
// 1. Copy data from VGPR to LDS
// 2. Copy data from LDS to VGPR
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
...@@ -1356,8 +1403,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
...@@ -1379,19 +1424,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(make_freeze_transform(I0),
make_unmerge_transform(
make_tuple(Number<CShuffleMXdlPerWavePerShuffle>{}, M1, M2, M3, M4)),
make_freeze_transform(I0),
make_unmerge_transform(
make_tuple(Number<CShuffleNXdlPerWavePerShuffle>{}, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
...@@ -1458,31 +1496,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
ThisThreadBlock,
CElementwiseOperation,
CGlobalMemoryDataOperation,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>,
CShuffleDataType,
CDataType,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>,
3,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true,
false>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// SpaceFillingCurve to combine all components
// space filling curve for threadwise C in VGPR
// C: VGPR to LDS
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
...@@ -1495,7 +1533,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// Shuffled C: VGPR to Global
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
...@@ -1504,22 +1544,84 @@ struct GridwiseGemm_xdl_cshuffle_v3
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
if constexpr(access_id < num_access - 1)
{
constexpr auto shuffle_m0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}];
constexpr auto shuffle_n0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}];
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
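// The branch above computes the last-K MFMAs for the next tile (access_id + 1), so the
// math overlaps with the VGPR->LDS and LDS->global copies of the current tile below.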
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
@@ -1533,9 +1635,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
EpilogueScheduler();
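// EpilogueScheduler() is not defined in this hunk; presumably it inserts scheduling or
// wait hints between successive epilogue tiles.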
}
});
}
@@ -1729,6 +1832,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_;
constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_;
constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc.GetElementSpaceSize());
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
@@ -1746,14 +1858,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
a_thread_buf,
b_thread_buf,
num_k_block_main_loop);
// shuffle C and write out {
{
// Last block MFMA
auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm;
constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
// Shuffle
// 1. Copy data from VGPR to LDS
// 2. Copy data from LDS to VGPR
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
@@ -1761,8 +1879,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -1784,19 +1900,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple( make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_unmerge_transform(
make_unmerge_transform(make_tuple( make_tuple(Number<CShuffleMXdlPerWavePerShuffle>{}, M1, M2, M3, M4)),
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle make_freeze_transform(I0),
M1, // M1 = MWave make_unmerge_transform(
M2, // M2 * M3 * M4 = MPerXdl make_tuple(Number<CShuffleNXdlPerWavePerShuffle>{}, N1, N2))),
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
@@ -1863,31 +1972,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup ThisThreadBlock,
CElementwiseOperation, // ElementwiseOperation, CElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, CGlobalMemoryDataOperation,
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>,
CShuffleDataType, // typename SrcData, CShuffleDataType,
CDataType, // typename DstData, CDataType,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>,
3, // index_t VectorDim, 3,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CShuffleBlockTransferScalarPerVector_NPerBlock,
true, // bool ThreadTransferSrcResetCoordinateAfterRun, true,
false> // bool ThreadTransferDstResetCoordinateAfterRun> false>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_multi_index(0, 0, 0, 0),
make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_m_id, 0, block_n_id, 0),
make_multi_index(block_m_id, 0, block_n_id, 0), c_element_op};
c_element_op};
// SpaceFillingCurve to combine all components
// space filling curve for threadwise C in VGPR // C: VGPR to LDS
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
@@ -1900,7 +2009,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// Shuffled C: VGPR to Global
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
@@ -1909,22 +2020,84 @@ struct GridwiseGemm_xdl_cshuffle_v3
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
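// Same tail-MFMA arrangement as in the epilogue above: the first shuffle tile's last
// K block is computed here, the rest is interleaved with the store loop below.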
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS if constexpr(access_id < num_access - 1)
{
constexpr auto shuffle_m0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}];
constexpr auto shuffle_n0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}];
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
@@ -1938,9 +2111,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
EpilogueScheduler();
}
});
}
...
@@ -1421,119 +1421,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
a_thread_buf,
b_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
#if 0 // Copy MultipleD data from HBM to VGPR
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
@@ -1547,84 +1437,50 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto ds_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<DsDataType,
DsDataType,
decltype(ds_grid_desc_),
decltype(ds_thread_desc_),
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
CDEShuffleBlockTransferScalarPerVectors,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc,
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
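// Note: the slice lengths and start index above still mirror the A-tensor loader
// (KWmmaPerBlock/K0PerWmma, a_grid_desc, MWaves * MPerWmma, wave/lane arithmetic);
// for the multiple-D epilogue these would presumably be replaced by D-tensor
// descriptors and per-tile offsets. Work in progress.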
// tuple of reference to C/Ds tensor descriptors // Last block MFMA
const auto c_ds_desc_refs = concat_tuple_of_reference( auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm;
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);;
using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
ThisThreadBlock, NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), "wrong!");
Tuple<EDataType>, // Shuffle
decltype(c_ds_desc_refs), // 1. Copy data from VGPR to LDS
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), // 2. Copy data from LDS to VGPR
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op};
#endif
// copy multipleD from global to vgpr
auto d_threadwise_copy;
// copy c from vgpr to lds
// TODO: Avoid bank conflict. 2 for mfma16x16, 0 for mfma32x32.
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// C SrcDesc in VGPR // TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// C DstDesc in LDS
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -1656,6 +1512,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
@@ -1713,125 +1571,114 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// copy c from lds to vgpr
auto c_threadwise_copy_lds_to_vgpr;
// copy e from vgpr to global
constexpr auto n_vec = CDEShuffleBlockTransferScalarPerVectors.At(Number<0>{});
constexpr auto n_thread = NPerBlock / n_vec / NRepeat;
constexpr auto m_thread = BlockSize / n_thread;
constexpr auto m_thread_repeat = MPerBlock / MRepeat / m_thread;
auto c_thread_desc_coalescing = make_naive_tensor_descriptor_packed(
make_tuple(I1, MRepeat, m_thread_repeat, I1, I1, NRepeat, I1, n_vec));
auto c_block_desc_coalescing = make_naive_tensor_descriptor_packed(
make_tuple(I1, MRepeat, m_thread_repeat, m_thread, I1, NRepeat, n_thread, n_vec));
auto c_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
CShuffleDataType,
decltype(c_block_desc_coalescing),
decltype(c_thread_desc_coalescing),
Sequence<I1,
CShuffleMXdlPerWavePerShuffle,
m_thread_repeat,
I1,
I1,
CShuffleMXdlPerWavePerShuffle,
I1,
n_vec>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
6,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc,
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
auto e_grid_desc_coalescing = // shuffle: blockwise copy C from LDS to global
make_naive_tensor_descriptor_packed(make_tuple(problem.MBlock, auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
MRepeat, ThisThreadBlock,
m_thread_repeat, CElementwiseOperation,
m_thread, CGlobalMemoryDataOperation,
problem.NBlock, Sequence<1,
NRepeat, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
n_thread, 1,
n_vec)); CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
auto e_threadwise_copy = Sequence<0, 1, 2, 3>,
ThreadwiseTensorSliceTransfer_v1r3<EDataType, CShuffleDataType,
EDataType, CDataType,
decltype(c_thread_desc_coalescing), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_coalescing), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
ck::tensor_operation::element_wise::PassThrough, Sequence<0, 1, 2, 3>,
Sequence<I1, 3,
CShuffleMXdlPerWavePerShuffle, CShuffleBlockTransferScalarPerVector_NPerBlock,
m_thread_repeat, true,
I1, false>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
I1, make_multi_index(0, 0, 0, 0),
CShuffleMXdlPerWavePerShuffle, c_grid_desc_mblock_mperblock_nblock_nperblock,
I1, make_multi_index(block_m_id, 0, block_n_id, 0),
n_vec>, c_element_op};
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // SpaceFillingCurve tocombine all components
n_vec, // C: VGPR to LDS
CGlobalMemoryDataOperation, constexpr auto sfc_c_vgpr =
1, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
true>{ Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
e_grid_desc_coalescing, Sequence<CShuffleMXdlPerWavePerShuffle,
make_multi_index(block_m_id, CShuffleNXdlPerWavePerShuffle,
I0, 1,
I0, 1,
get_thread_local_1d_id() / n_thread, M2,
block_n_id, 1,
I0, M4,
get_thread_local_1d_id() % n_thread, 1>>{};
I0),
ck::tensor_operation::element_wise::PassThrough{}};
auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm; constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
constexpr auto MRepeat = MXdlPerWave;
constexpr auto NRepeat = NXdlPerWave;
constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && // Shuffled C: VGPR to Global
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, constexpr auto sfc_c_global =
"wrong!"); SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
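// As in the universal GEMM epilogue: issue the last K block's MFMAs for the first
// shuffle tile here, before the per-access write-out loop starts.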
static_for<0, MRepeat / CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto shuffle_m0) { static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, NRepeat / CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto shuffle_n0) { static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
// MultipleD buffer load
d_threadwise_copy.Run(c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
// Tail MFMA
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { vector_type<ComputeTypeA, KPack> a_thread_vec;
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_sched_barrier(0);
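// Scheduler fence, as in the universal GEMM epilogue above.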
static_for<0, num_access, 1>{}([&](auto access_id) {
ds_thread_copy_global_to_vgpr.Run();
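// D-tensor prefetch for this access; Run() is called with no arguments here, presumably
// a placeholder still to be wired up to the Ds descriptors and buffers so the global
// loads can overlap with the LDS traffic below.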
block_sync_lds();
if constexpr(access_id < num_access - 1)
{
constexpr auto shuffle_m0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}];
constexpr auto shuffle_n0 =
sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}];
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec; vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeA, KPack> b_thread_vec; vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) = a_thread_buf a_thread_vec.template AsType<ComputeTypeA>()(ik) =
[Number<a_thread_desc.CalculateOffset(make_tuple( a_thread_buf[Number<a_thread_desc.CalculateOffset(
shuffle_m0 * CShuffleMXdlPerWavePerShuffle + m0, make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
I0, b_thread_vec.template AsType<ComputeTypeA>()(ik) =
k0, b_thread_buf[Number<b_thread_desc.CalculateOffset(
ik))>{}]; make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) = b_thread_buf
[Number<b_thread_desc.CalculateOffset(make_tuple(
shuffle_n0 * CShuffleNXdlPerWavePerShuffle + n0,
I0,
k0,
ik))>{}];
}); });
using mfma_input_type =
@@ -1839,9 +1686,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 * CShuffleMXdlPerWavePerShuffle + m0, make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
shuffle_n0 * CShuffleNXdlPerWavePerShuffle + n0,
0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
@@ -1850,38 +1695,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
});
});
});
// Shuffle: DS_WRITE }
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(Number<shuffle_m0 * MRepeat>{},
Number<shuffle_n0 * NRepeat>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{}),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
block_sync_lds();
// Shuffle: DS_READ
c_thread_copy_lds_to_vgpr.Run();
cde_element();
e_threadwise_copy.Run(c_thread_desc_coalescing,
make_tuple(Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{},
Number<0>{}),
e_thread_buf,
e_grid_desc_coalescing,
c_grid_buf);
// move e_grid desc slice origin
}); c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
ds_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
EpilogueScheduler();
}
});
}
}
...
@@ -389,6 +389,7 @@ struct ThreadwiseTensorSliceTransfer_v2
SrcCoord src_coord_;
}; // namespace ck
#if 0
// Multiple DynamicBuffer to multiple StaticBuffer
// Assume:
// 1. src:
@@ -563,44 +564,29 @@ struct ThreadwiseTensorSliceTransfer_v2r1
}
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowStepHack> template <index_t ISrc>
__device__ void __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
MoveSrcSliceWindow(const SrcDesc& src_desc, Number<ISrc> iSrc,
const Index& src_slice_origin_step_idx, const Index& src_slice_origin_step_idx)
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{ {
// if src coord was not reset by RunRead(), then need to adjust the step here // if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx SrcResetCoordinateAfterRunFlags::At(iSrc)
: src_slice_origin_step_idx + GetSrcCoordinateResetStep(); ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time? // is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step( const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step); move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
} }
private:
SrcCoord src_coord_; SrcCoords src_coords_;
}; // namespace ck
#endif
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
...
@@ -47,7 +47,7 @@ if(GPU_TARGETS MATCHES "gfx9")
# endif()
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# if(GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
# endif()
# list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
@@ -131,7 +131,7 @@ if(GPU_TARGETS MATCHES "gfx9")
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# if(GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
# endif()
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
...