Commit e78fbf87 authored by coderfeli's avatar coderfeli
Browse files

Merge the two MoE GEMM pipelines (gather and scatter) into a single unified pipeline

parent 1687fc98
......@@ -12,8 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp"
// #include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
......@@ -66,7 +66,7 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool IsGatherGemm = true,
bool IsInputGemm = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
......@@ -85,8 +85,8 @@ struct DeviceMoeGemm
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = std::conditional_t<IsGatherGemm,
GridwiseMoeGemmGather<
using GridwiseGemm =
GridwiseMoeGemm<
ALayout,
BLayout,
DsLayout,
......@@ -136,58 +136,7 @@ struct DeviceMoeGemm
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>,
GridwiseMoeGemmScatter<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>>;
LDSTypeB>;
using Argument = typename GridwiseGemm::Argument;
......@@ -305,86 +254,51 @@ struct DeviceMoeGemm
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// if constexpr (IsGatherGemm) {
// const auto kernel = kernel_moe_gemm_gather<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
// else {
// const auto kernel = kernel_moe_gemm_scatter<
// const auto kernel = kernel_moe_gemm<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Odd>;
// RunKernel(kernel);
// }
// }
// else
// {
// if constexpr (IsGatherGemm) {
// const auto kernel = kernel_moe_gemm_gather<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// RunKernel(kernel);
// else {
// const auto kernel = kernel_moe_gemm_scatter<
// const auto kernel = kernel_moe_gemm<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Even>;
// RunKernel(kernel);
// }
// }
// }
// else
{
constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// if constexpr (IsGatherGemm) {
// const auto kernel = kernel_moe_gemm_gather<
// const auto kernel = kernel_moe_gemm<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// MemoryDataOp,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Odd>;
// RunKernel(kernel);
// } else {
// const auto kernel = kernel_moe_gemm_scatter<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
// }
// }
// else
{
if constexpr (IsGatherGemm) {
const auto kernel = kernel_moe_gemm_gather<
const auto kernel = kernel_moe_gemm<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
} else {
const auto kernel = kernel_moe_gemm_scatter<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
}
......@@ -423,7 +337,7 @@ struct DeviceMoeGemm
// kernel_moe_gemm_gather_2lds<
// GridwiseGemm,
// true,
// IsGatherGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
......@@ -434,7 +348,7 @@ struct DeviceMoeGemm
// kernel_moe_gemm_gather_2lds<
// GridwiseGemm,
// true,
// IsGatherGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// RunKernel(kernel);
......
......@@ -30,20 +30,21 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_gemm_gather(typename GridwiseGemm::Argument karg)
kernel_moe_gemm(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
......@@ -145,7 +146,7 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType>
struct GridwiseMoeGemmGather
struct GridwiseMoeGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -1121,6 +1122,7 @@ struct GridwiseMoeGemmGather
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(
const index_t* p_sorted_token_ids,
......@@ -1138,11 +1140,11 @@ struct GridwiseMoeGemmGather
{
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
IsInputGemm? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.NumTokens * problem.TopK, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
IsInputGemm? problem.NumTokens * problem.TopK : problem.NumTokens , problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
......@@ -1177,8 +1179,12 @@ struct GridwiseMoeGemmGather
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t token_offset = (token_pos + m0 < max_token_id) ?
(p_sorted_token_ids[token_pos + m0] & 0xffffff) : problem.NumTokens;
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
if constexpr (!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
......@@ -1464,16 +1470,26 @@ struct GridwiseMoeGemmGather
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
const float *p_sorted_weights = p_ds_grid[I0];
const float *p_sorted_weights_0 = p_ds_grid[I0];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
scatter_offsets(m0) = ((fused_token & 0xffffff) * problem.TopK + (fused_token >> 24)) * problem.N;
scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
......@@ -1502,7 +1518,7 @@ struct GridwiseMoeGemmGather
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1, //ScatterDim
true, //OutputScatter: false, only use scatter weights
1 // ScatterWeightIdx: ascale
scatter_weight_idx // ScatterWeightIdx: ascale
>
{c_ds_desc_refs,
idx_c_ds_block_begin,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment