Commit 182e7480 authored by mtgu0705

Split the blockwise pipeline for fp8xint4.

parent 966f9051
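Note on the change: the b-preshuffle pipeline selector now dispatches to a dedicated dequantizing pipeline whenever the A and B element types differ (the fp8-activation x packed-int4-weight case), while the existing pipeline keeps serving the homogeneous case. A minimal sketch of the dispatch idea, using simplified stand-in names rather than the real CK signatures:

// Sketch only: simplified stand-ins for the CK pipeline structs selected in this commit.
#include <type_traits>

struct PipelineBPreshuffleV1 {};         // same-type path (e.g. fp8 x fp8)
struct PipelineBPreshuffleDequantV1 {};  // mixed-type path (e.g. fp8 x packed int4)

template <typename ADataType, typename BDataType>
constexpr auto select_bpreshuffle_pipeline_v1()
{
    if constexpr(std::is_same_v<ADataType, BDataType>)
        return PipelineBPreshuffleV1{};         // no dequant step needed
    else
        return PipelineBPreshuffleDequantV1{};  // dequantize B in registers before the MFMAs
}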
@@ -268,8 +268,8 @@ int main(int argc, char* argv[])
     expert_ids.savetxt("expert_ids.txt", "int");
     sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
     Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
-    Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-    Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
+    Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
+    Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, K}));
     Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
     Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
     Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
...
@@ -268,8 +268,8 @@ int main(int argc, char* argv[])
     expert_ids.savetxt("expert_ids.txt", "int");
     sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
     Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk*K, K, 1}));
-    Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-    Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
+    Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N*K, 1, N}));
+    Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N*K, 1, N}));
     Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
     Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
     Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
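// This variant serves the mixed-precision (e.g. fp8 x packed-int4) path: B is loaded from
// global memory directly into VGPRs (the LDS block buffer is ignored for B) and converted
// from BDataType to ComputeDataType with a VGPR->VGPR copy before the MFMA stage.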
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1
{
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::A_K1;
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::a_block_desc_m0_m1_m2_k;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::AMmaKStride;
using Base::BMmaKStride;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
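// A is double-buffered through global prefetch: RunRead/RunWrite use two register staging
// buffers (indexed I0/I1), letting the read of the next A tile stay in flight while the
// current one is written to LDS.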
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
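// Manual scheduling for the hot loop: the sched_group_barrier intrinsics below interleave
// MFMA instructions with B buffer_loads, A buffer_loads/DS writes, and A DS reads so that
// memory operations are spread across the math instructions.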
__device__ static constexpr auto HotLoopScheduler()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A local
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
});
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
ignore = b_block_buf;
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_thread_desc_.GetElementSpaceSize());
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
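// Two in-register B buffers per kind: b_thread_bufs holds the raw quantized values for the
// current and next iteration, b_thread_dequant_bufs the dequantized ComputeDataType values.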
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// // Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(local_read_buf));
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
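// Unroll by two so the buffer roles alternate: compute from buffer 0 while buffer 1 is
// being filled, then swap.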
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I1));
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// Let the last MFMA block leak into the epilogue region to cover the potential LDS-shuffle
// latency.
// __builtin_amdgcn_sched_barrier(0);
}
else
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
// MRepeat MWave MLane KRepeat KLane KPack
// KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_k0_k1),
decltype(b_block_desc_n0_n1_k0_k1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
Sequence<1, 2, 0, 3>,
3,
KPack>;
const PassThrough b_element_op{};
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
};
} // namespace ck
@@ -33,7 +33,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPacks>
-struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3
+struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3
 {
 };
@@ -58,7 +58,7 @@ template <index_t BlockSize,
           index_t KPack
           // ,bool TransposeC //disable transposec right now...
           >
-struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::Intrawave,
+struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineScheduler::Intrawave,
     BlockSize,
     ADataType,
     BDataType,
...
@@ -4,8 +4,9 @@
 #pragma once
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
-#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
 namespace ck {
@@ -33,6 +34,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
 constexpr auto BlockGemmBPreshufflePipeline_Selector()
 {
     if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
     {
+        if(std::is_same<ADataType, BDataType>::value)
+        {
             return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
                                                                BlockSize,
@@ -55,6 +58,31 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
                                                                NRepeat,
                                                                KPack>{};
         }
+        else
+        {
+            return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<
+                BlkGemmPipeSche,
+                BlockSize,
+                ADataType,
+                BDataType,
+                ComputeDataType,
+                AccDataType,
+                ATileDesc,
+                BTileDesc,
+                AMmaTileDesc,
+                BMmaTileDesc,
+                ABlockTransferSrcScalarPerVector,
+                BBlockTransferSrcScalarPerVector,
+                MPerBlock,
+                NPerBlock,
+                KPerBlock,
+                MPerXDL,
+                NPerXDL,
+                MRepeat,
+                NRepeat,
+                KPack>{};
+        }
+    }
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
     {
         return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlkGemmPipeSche,
@@ -81,7 +109,14 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
     else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
     {
         static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
-        return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
+        if(std::is_same<ADataType, BDataType>::value)
+        {
+            std::cerr << "BlockGemmPipeline v3 configuration is not available" << std::endl;
+        }
+        else
+        {
+            return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<
+                BlkGemmPipeSche,
                 BlockSize,
                 ADataType,
                 BDataType,
@@ -102,6 +137,7 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
                 NRepeat,
                 KPack>{};
         }
+    }
     else
     {
         std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -144,7 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
     static constexpr index_t PrefetchStages = 2;
     static constexpr index_t PrefillStages = 1;
-    static constexpr index_t GlobalBufferNum = 2;
+    static constexpr index_t GlobalBufferNum = 1;
     template <typename TileDesc_M0_M1_M2_K>
     __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
@@ -235,7 +235,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                 const BGridBuffer& b_grid_buf,
                 BBlockBuffer& b_block_buf,
                 const BBlockTransferStep& b_block_copy_step,
-                // BThreadTransfer& b_thread_dequant_copy,
                 CThreadBuffer& c_thread_buf,
                 index_t num_loop) const
     {
@@ -243,19 +242,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
         __builtin_amdgcn_sched_barrier(0);
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
-            b_thread_desc_.GetElementSpaceSize());
-        auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_thread_desc_.GetElementSpaceSize());
         StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
         constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
-        StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
         // Global prefetch A1 B1
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.Run(b_grid_desc,
                              b_grid_buf,
                              b_block_desc_n0_n1_k0_k1,
@@ -264,13 +258,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-        __builtin_amdgcn_sched_barrier(0);
         // // Local prefill A1
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
         // // Global prefetch A2
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         // Local prefetch A1
@@ -285,13 +278,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                     a_thread_buf);
             });
         });
-        // B VGPR->VGPR dequant
-        b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
-                                   b_block_origin_idx,
-                                   b_thread_bufs(I0),
-                                   b_thread_desc_,
-                                   make_tuple(I0, I0, I0, I0),
-                                   b_thread_dequant_bufs(I0));
         // Initialize C
         c_thread_buf.Clear();
@@ -310,14 +296,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                         b_block_desc_n0_n1_k0_k1,
                         b_block_origin_idx,
                         b_thread_bufs(local_read_buf));
                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                     block_sync_lds();
-                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
-                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
+                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                    // printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x,
+                    // type_convert<float>(a_thread_buf[I0]),
+                    // type_convert<float>(b_thread_bufs[mfma_reg_buf][I0]));
                     static_for<0, MRepeat, 1>{}([&](auto m0) {
                         static_for<0, NRepeat, 1>{}([&](auto n0) {
                             static_for<0, KRepeat, 1>{}([&](auto k0) {
@@ -329,7 +319,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                         a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                             make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                                     b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                        b_thread_dequant_bufs[mfma_reg_buf]
+                                        b_thread_bufs[mfma_reg_buf]
                                             [Number<b_thread_desc_.CalculateOffset(
                                                 make_tuple(n0, I0, k0, ik))>{}];
                                 });
@@ -360,13 +350,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                     a_thread_buf);
                         });
                     });
-                    // B VGPR->VGPR dequant
-                    b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
-                                               b_block_origin_idx,
-                                               b_thread_bufs(local_read_buf),
-                                               b_thread_desc_,
-                                               make_tuple(I0, I0, I0, I0),
-                                               b_thread_dequant_bufs(local_read_buf));
                     HotLoopScheduler();
                     __builtin_amdgcn_sched_barrier(0);
@@ -401,7 +384,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, I0, k0, ik))>{}];
                         });
@@ -430,13 +413,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                             a_thread_buf);
                 });
             });
-            // B VGPR->VGPR dequant
-            b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
-                                       b_block_origin_idx,
-                                       b_thread_bufs(I1),
-                                       b_thread_desc_,
-                                       make_tuple(I0, I0, I0, I0),
-                                       b_thread_dequant_bufs(I1));
             __builtin_amdgcn_sched_barrier(0);
@@ -451,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
+                                b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, I0, k0, ik))>{}];
                         });
@@ -484,7 +460,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, I0, k0, ik))>{}];
                         });
@@ -527,22 +503,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
     static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
     using Base::c_thread_desc_;
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-    using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-        BDataType,
-        ComputeDataType,
-        decltype(b_block_desc_n0_n1_k0_k1),
-        decltype(b_block_desc_n0_n1_k0_k1),
-        tensor_operation::element_wise::PassThrough,
-        Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
-        Sequence<1, 2, 0, 3>,
-        3,
-        KPack>;
-    const PassThrough b_element_op{};
-    BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
 };
 } // namespace ck
@@ -196,6 +196,20 @@ struct GridwiseMoeGemm
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+    static constexpr index_t APackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+    static constexpr index_t BPackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
     __host__ static auto CalculateGridSize(index_t M, index_t N)
     {
         const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
@@ -385,6 +399,10 @@ struct GridwiseMoeGemm
         using GemmSpecialization = tensor_operation::device::GemmSpecialization;
+        static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
+                        GemmSpec != GemmSpecialization::Default),
+                      "pk_i4_t does not support padding");
         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
@@ -681,7 +699,7 @@ struct GridwiseMoeGemm
     {
         if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
         {
-            a_k_split_offset = k_id * karg.KRead;
+            a_k_split_offset = k_id * karg.KRead / APackedSize;
         }
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
@@ -695,7 +713,7 @@ struct GridwiseMoeGemm
         else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
             // KPack * NLane * KLane * K0 * N0
-            b_k_split_offset = k_id * karg.KRead * NLane;
+            b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
         }
         if(k_id < karg.KBatch - 1)
@@ -725,7 +743,7 @@ struct GridwiseMoeGemm
         // in some cases.
         else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
         {
-            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
+            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize < 1
                 ? 1
                 : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
             constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
@@ -875,8 +893,8 @@ struct GridwiseMoeGemm
         BlkGemmPipelineVer,
         BlkGemmPipeSched,
         BlockSize,
-        LDSTypeA,
-        LDSTypeB,
+        ADataType,
+        BDataType,
         ComputeTypeA,
         AccDataType,
         decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
@@ -913,7 +931,7 @@ struct GridwiseMoeGemm
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-        return math::max(a_block_space_size_aligned * sizeof(LDSTypeA),
+        return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
                          c_block_size * sizeof(CShuffleDataType));
     }
@@ -1209,7 +1227,7 @@ struct GridwiseMoeGemm
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
+            p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize());
         // if(threadIdx.x==0)
         // printf("tid %d eid %d expert_stride %d bufsize %d\n",
         // threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...
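For reference on the PackedSize correction above, a minimal standalone sketch (not the CK implementation): offsets counted in logical int4 elements are divided by the packed size before being applied to a packed-weight pointer, which is what the expert_stride / BPackedSize and KRead / APackedSize changes do.

// Sketch only: two int4 values share one stored byte, so logical-element offsets halve.
#include <cstddef>

constexpr std::size_t packed_offset(std::size_t logical_elements, std::size_t packed_size)
{
    return logical_elements / packed_size;
}

static_assert(packed_offset(4096, 2) == 2048, "4096 logical int4 values occupy 2048 bytes");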
@@ -100,7 +100,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
                 {
                     if constexpr(is_same_v<ADataType, pk_i4_t>)
                     {
-                        uint8_t i4x2 = arg.a_t_k_(m, k).data;
+                        uint8_t i4x2 = arg.a_t_k_(t, topk_id, k).data;
                         uint8_t i4 = 0;
                         if(k % 2 == 1)
                             i4 = (i4x2 >> 0) & 0xf;
@@ -124,7 +124,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
                     }
                     else
                     {
-                        arg.b_element_op_(v_b, arg.b_e_n_k_(e, n, k));
+                        arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
                     }
                     v_acc +=
...