Commit 66e61076 authored by aska-0096

Sanity pass.

parent 0c51a35e
......@@ -9,7 +9,7 @@ using ADataType = ck::half_t;
using BDataType = int8_t;
using ScaleDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
......
......@@ -28,7 +28,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
// assume the scale tensor is [1, N]
Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{}));
Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{}));
switch(config.init_method)
{
......@@ -51,7 +51,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{1.f, 1.f}(scale_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{2.f, 2.f}(scale_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
......@@ -64,6 +64,50 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
}
#if 0
printf("Matrix A:\n");
for (int im = 0; im < M; im++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %04x", *(reinterpret_cast<uint16_t*>(&a_m_k(im,ik))));
}
printf("\n");
}
printf("Matrix B:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %02x", b_k_n(ik,in));
}
printf("\n");
}
printf("Matrix Scale:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %04x", *(reinterpret_cast<uint16_t*>(&scale_k_n(ik,in))));
}
printf("\n");
}
#endif
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
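For context, a host-reference sketch of what this example checks. This is plain C++ rather than CK code; fp16 values are held as float for simplicity, layouts are taken as row-major, and the [1, N] scale (stored with stride 0 along K above) multiplies each int8 weight before accumulation. Names are illustrative.

```cpp
// Host-reference model of the dequantized GEMM: C = A * (scale .* dequant(B)).
#include <cstdint>
#include <vector>

void reference_fpAintB_gemm(const std::vector<float>& a_m_k,   // M x K, row-major
                            const std::vector<int8_t>& b_k_n,  // K x N, row-major
                            const std::vector<float>& scale_n, // 1 x N, broadcast over K
                            std::vector<float>& c_m_n,         // M x N, row-major
                            int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                // dequantize B on the fly: int8 weight times its per-column scale
                const float b_deq = scale_n[n] * static_cast<float>(b_k_n[k * N + n]);
                acc += a_m_k[m * K + k] * b_deq;
            }
            c_m_n[m * N + n] = acc;
        }
}
```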
......
......@@ -309,7 +309,8 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_desc_.GetElementSpaceSize());
auto scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_thread_desc_.GetElementSpaceSize());
auto converted_b_thread_buf = b_thread_buf;
auto converted_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
b_thread_desc_.GetElementSpaceSize());
// basic heuristic to determine the loop-over direction
if constexpr(MRepeat < NRepeat)
......@@ -345,7 +346,7 @@ struct Blockwise_fpAintB_GemmWMMA
scale_thread_buf);
// convert B from int8 to fp16 and multiply by the scale
static_for<0, b_thread_buf.size(), 1>{}([&](auto i) {
static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
converted_b_thread_buf(i) =
scale_thread_buf[i / WmmaK] *
type_convert<ADataType>(b_thread_buf[i]);
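A scalar illustration of the conversion loop above, outside the CK static_for/static_buffer machinery (buffer sizes here are assumptions): every contiguous run of WmmaK int8 values in the B thread buffer is dequantized with the single scale entry selected by i / WmmaK.

```cpp
// Scalar model of the per-thread dequantization (illustrative sizes only).
#include <array>
#include <cstdint>

constexpr int WmmaK   = 16;
constexpr int BufSize = 32; // e.g. two WmmaK slices held per thread

std::array<float, BufSize>
dequantize_thread_buf(const std::array<int8_t, BufSize>& b_thread_buf,
                      const std::array<float, BufSize / WmmaK>& scale_thread_buf)
{
    std::array<float, BufSize> converted_b_thread_buf{};
    for(int i = 0; i < BufSize; ++i)
        // one scale entry covers a whole WmmaK-wide slice of B
        converted_b_thread_buf[i] =
            scale_thread_buf[i / WmmaK] * static_cast<float>(b_thread_buf[i]);
    return converted_b_thread_buf;
}
```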
......@@ -390,6 +391,20 @@ struct Blockwise_fpAintB_GemmWMMA
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read weight scale
scale_thread_copy_.Run(
scale_block_desc_1_n0_n1_n2_1,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
#if 0
printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n",
get_thread_local_1d_id(), n0.value,
*(reinterpret_cast<const uint16_t*>(&scale_thread_buf[n0]))
);
#endif
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
......@@ -400,16 +415,7 @@ struct Blockwise_fpAintB_GemmWMMA
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read weight scale
scale_thread_copy_.Run(
scale_block_desc_1_n0_n1_n2_1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
b_thread_buf);
// convert B from int8 to fp16 and multiply by the scale
static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] *
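This hunk hoists the scale read out of the m0/k loops: the per-column scale does not change with the K iteration, so one read per n0 suffices. A self-contained sketch of that loop structure, with illustrative sizes and names:

```cpp
// Loop restructuring model: the scale is k-invariant, so it is read once per
// n0 instead of once per (m0, k) iteration (plain C++, illustrative only).
#include <cstdio>

int main()
{
    constexpr int NRepeat = 2, MRepeat = 2, KSlices = 2;
    const float scale_block[NRepeat]          = {2.0f, 0.5f};          // one scale per n0
    const float b_block[NRepeat][KSlices]     = {{1.f, 2.f}, {3.f, 4.f}};
    float acc[MRepeat][NRepeat]               = {};

    for(int n0 = 0; n0 < NRepeat; ++n0)
    {
        const float scale = scale_block[n0]; // hoisted: k-invariant read
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int k = 0; k < KSlices; ++k)
                acc[m0][n0] += scale * b_block[n0][k]; // scale reused across m0 and k
    }

    std::printf("acc[0][0] = %f, acc[0][1] = %f\n", acc[0][0], acc[0][1]);
    return 0;
}
```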
......@@ -423,7 +429,71 @@ struct Blockwise_fpAintB_GemmWMMA
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
if (true){
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<0>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<1>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<2>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<3>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<4>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<5>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<6>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<7>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<8>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<9>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<10>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<11>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<12>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<13>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<14>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<15>{}]))
);
#endif
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, b_thread_buf: %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
b_thread_buf[Number<0>{}],
b_thread_buf[Number<1>{}],
b_thread_buf[Number<2>{}],
b_thread_buf[Number<3>{}],
b_thread_buf[Number<4>{}],
b_thread_buf[Number<5>{}],
b_thread_buf[Number<6>{}],
b_thread_buf[Number<7>{}],
b_thread_buf[Number<8>{}],
b_thread_buf[Number<9>{}],
b_thread_buf[Number<10>{}],
b_thread_buf[Number<11>{}],
b_thread_buf[Number<12>{}],
b_thread_buf[Number<13>{}],
b_thread_buf[Number<14>{}],
b_thread_buf[Number<15>{}]
);
#endif
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, converted_b_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<0>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<1>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<2>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<3>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<4>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<5>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<6>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<7>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<8>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<9>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<10>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<11>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<12>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<13>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<14>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<15>{}]))
);
#endif
}
vector_type<ADataType, WmmaK> a_thread_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
......@@ -497,7 +567,7 @@ struct Blockwise_fpAintB_GemmWMMA
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
I1),
make_tuple(I0, I1, I0, I0, I0, I0));
// C[M, N, NumRegWMMA]
......@@ -587,11 +657,11 @@ struct Blockwise_fpAintB_GemmWMMA
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, 1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
1,
1>;
};
template <>
......
......@@ -182,8 +182,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
// When K = 1, it might be a scale tensor.
assert(K % K1 == 0 && K != 1);
assert(K % K1 == 0);
if constexpr(BEnableLds)
{
......@@ -216,6 +215,52 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
}
}
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
{
// assume Scale is [1, N]
const auto scale_grid_desc_n_k = [&]() {
const auto scale_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
}();
const auto N = scale_grid_desc_n_k.GetLength(I0);
const auto K = scale_grid_desc_n_k.GetLength(I1);
// When K = 1, it might be a scale tensor.
assert(K % K1 == 0 && K != 1);
if constexpr(BEnableLds)
{
const index_t K0 = K / K1;
return transform_tensor_descriptor(
scale_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
// 0 1 <-> 0 1 2 3 4 5 6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return transform_tensor_descriptor(
scale_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
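A plain-C++ illustration (not the CK descriptor API) of what the stride-0 descriptor built above encodes: the physical scale buffer holds only N values, but it is addressed as an (N, K) view whose K stride is 0, so every k reads the same per-column scale.

```cpp
#include <cassert>
#include <vector>

// Index math of a broadcast view: offset(n, k) = n * 1 + k * 0.
float scale_at(const std::vector<float>& scale /* length N */, int n, int k)
{
    const int stride_n = 1; // contiguous along N
    const int stride_k = 0; // broadcast along K
    return scale[n * stride_n + k * stride_k];
}

int main()
{
    const std::vector<float> scale = {1.f, 2.f, 4.f}; // N = 3
    assert(scale_at(scale, 1, 0) == 2.f);             // any k maps to the same element
    assert(scale_at(scale, 1, 7) == 2.f);
    return 0;
}
```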
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
......@@ -237,7 +282,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
using ScaleGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
......@@ -330,7 +375,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, 0);
scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
......
......@@ -52,6 +52,12 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
if (false && get_thread_local_1d_id()==0){
printf("lds_size: %lu\n", GridwiseGemm::SharedMemTrait::lds_size);
printf("lds_a_size: %d\n", GridwiseGemm::SharedMemTrait::a_block_space_size_aligned);
printf("lds_b_size: %d\n", GridwiseGemm::SharedMemTrait::b_block_space_size_aligned);
printf("lds_scale_size: %d\n", GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......@@ -262,7 +268,7 @@ struct GridwiseFpAintBGemm_Wmma
constexpr auto K0PerBlock = KPerBlock / K1;
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, I1),
make_tuple(I0, I1, I0));
}
else
......@@ -276,7 +282,7 @@ struct GridwiseFpAintBGemm_Wmma
Number<K0PerWmma>{},
I1,
I1,
K1),
I1),
make_tuple(I0, I1, I0, I0, I0, I0, I0));
}
}();
......@@ -424,6 +430,52 @@ struct GridwiseFpAintBGemm_Wmma
return b_wave_desc;
}
template <typename ScaleBlockDesc_>
__host__ __device__ static constexpr auto MakeScaleWaveDescriptor(const ScaleBlockDesc_&)
{
constexpr auto scale_wave_desc = [&]() {
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = ScaleBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
ScaleBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
constexpr auto KWmma = ScaleBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ScaleBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = ScaleBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I6);
// Workaround: freeze the transform and use a naive descriptor whose strides are zero except along NRepeat
return make_naive_tensor_descriptor(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(I0,
I1,
I0,
I0,
I0,
I0));
}
}();
return scale_wave_desc;
}
__host__ __device__ static constexpr auto
// *Caution: here, repeat means the shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
......@@ -590,9 +642,10 @@ struct GridwiseFpAintBGemm_Wmma
: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
static constexpr auto b_block_space_offset =
(a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType)/sizeof(BDataType);
static constexpr auto scale_block_space_offset =
b_block_space_offset + b_block_space_size_aligned;
(b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType)/sizeof(ScaleDataType);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
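The offsets above are expressed in elements of each buffer's own data type, since the LDS base pointer is cast to that type before the offset is added. A standalone sketch of the unit conversion with illustrative tile sizes (not the kernel's real ones):

```cpp
// The A/B/scale blocks tile one shared byte pool back to back; converting the
// preceding footprint into the next buffer's element size keeps byte offsets intact.
#include <cstddef>
#include <cstdint>

int main()
{
    using ADataType     = uint16_t; // stands in for ck::half_t
    using BDataType     = int8_t;
    using ScaleDataType = uint16_t;

    constexpr std::size_t a_block_space_size_aligned = 1024; // in ADataType elements
    constexpr std::size_t b_block_space_size_aligned = 2048; // in BDataType elements

    constexpr std::size_t a_block_space_offset = 0;
    constexpr std::size_t b_block_space_offset =
        (a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType) / sizeof(BDataType);
    constexpr std::size_t scale_block_space_offset =
        (b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType) / sizeof(ScaleDataType);

    static_assert(b_block_space_offset * sizeof(BDataType) ==
                      a_block_space_size_aligned * sizeof(ADataType),
                  "B block starts right after the A block in bytes");
    static_assert(scale_block_space_offset * sizeof(ScaleDataType) ==
                      b_block_space_offset * sizeof(BDataType) +
                          b_block_space_size_aligned * sizeof(BDataType),
                  "scale block starts right after the B block in bytes");
    return 0;
}
```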
......@@ -753,7 +806,7 @@ struct GridwiseFpAintBGemm_Wmma
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);
// printf("b_lds_offset: %lu\n", SharedMemTrait::b_block_space_offset);
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
......@@ -834,13 +887,15 @@ struct GridwiseFpAintBGemm_Wmma
auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
SharedMemTrait::scale_block_space_size_aligned);
// printf("scale_lds_offset: %lu\n", SharedMemTrait::scale_block_space_offset);
auto scale_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
// Reduce slice length K1 to 1
Sequence<K0PerBlock, NPerBlock, I1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
ScaleDataType,
......@@ -851,10 +906,10 @@ struct GridwiseFpAintBGemm_Wmma
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
1, // no effect
1, // no effect
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
......@@ -926,7 +981,7 @@ struct GridwiseFpAintBGemm_Wmma
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
decltype(MakeBWaveDescriptor(scale_block_desc)),
decltype(MakeScaleWaveDescriptor(scale_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
......
......@@ -581,9 +581,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename ScaleBlockTransfer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
......@@ -658,6 +658,116 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
}
};
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
scale_blockwise_copy.Run(scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
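A simplified, self-contained model (plain C++, illustrative only, not the CK API) of the schedule this new <1, true, false> specialization follows: A tiles are staged through a shared buffer, B tiles ping-pong between two register-like buffers so the next tile is fetched while the current one is consumed, and the scale block is loaded once up front since it is loop-invariant.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main()
{
    constexpr int num_loop = 4;
    std::vector<int> a_buf(1), b_buf(1), b_buf_switch(1);
    int scale = 0;

    auto load_a = [&](int i) { a_buf[0] = i; };                      // stands in for A RunRead/RunWrite
    auto load_b = [&](std::vector<int>& dst, int i) { dst[0] = i; }; // B goes to a register buffer
    auto gemm   = [&] {
        std::printf("gemm on A tile %d, B tile %d, scale %d\n", a_buf[0], b_buf[0], scale);
    };

    // prologue: first A and B tiles plus the loop-invariant scale
    load_a(0);
    load_b(b_buf, 0);
    scale = 1;

    // main body: prefetch the next B tile into the spare buffer, compute, refill A, swap
    for(int i = 0; i < num_loop - 1; ++i)
    {
        load_b(b_buf_switch, i + 1);
        gemm();
        load_a(i + 1);
        std::swap(b_buf, b_buf_switch);
    }

    // tail: last tile, nothing left to prefetch
    gemm();
    return 0;
}
```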
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
......