"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "6f22f08dc2f2e65de97ef587cadc6d2b49c90290"
Commit e78fbf87 authored by coderfeli's avatar coderfeli
Browse files

Merge the two MoE GEMM pipes together

parent 1687fc98
@@ -12,8 +12,8 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp"
+// #include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/flush_cache.hpp"
@@ -66,7 +66,7 @@ template <typename ALayout,
           typename CDEShuffleBlockTransferScalarPerVectors,
           BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
-          bool IsGatherGemm = true,
+          bool IsInputGemm = true,
           typename ComputeTypeA = CDataType,
           typename ComputeTypeB = ComputeTypeA,
           typename LDSTypeA = ComputeTypeA,
@@ -85,8 +85,8 @@ struct DeviceMoeGemm
                             CElementwiseOperation>
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
-    using GridwiseGemm = std::conditional_t<IsGatherGemm,
-        GridwiseMoeGemmGather<
+    using GridwiseGemm =
+        GridwiseMoeGemm<
             ALayout,
             BLayout,
             DsLayout,
@@ -136,58 +136,7 @@ struct DeviceMoeGemm
             ComputeTypeA,
             ComputeTypeB,
             LDSTypeA,
-            LDSTypeB>,
-        GridwiseMoeGemmScatter<
-            ALayout,
-            BLayout,
-            DsLayout,
-            CLayout,
-            ADataType,
-            BDataType,
-            GemmAccDataType,
-            CShuffleDataType,
-            DsDataType,
-            CDataType,
-            AElementwiseOperation,
-            BElementwiseOperation,
-            CElementwiseOperation,
-            GemmSpec,
-            BlockSize,
-            MPerBlock,
-            NPerBlock,
-            KPerBlock,
-            AK1,
-            BK1,
-            MPerXDL,
-            NPerXDL,
-            MXdlPerWave,
-            NXdlPerWave,
-            ABlockTransferThreadClusterLengths_AK0_M_AK1,
-            ABlockTransferThreadClusterArrangeOrder,
-            ABlockTransferSrcAccessOrder,
-            ABlockTransferSrcVectorDim,
-            ABlockTransferSrcScalarPerVector,
-            ABlockTransferDstScalarPerVector_AK1,
-            false,
-            ABlockLdsExtraM,
-            BBlockTransferThreadClusterLengths_BK0_N_BK1,
-            BBlockTransferThreadClusterArrangeOrder,
-            BBlockTransferSrcAccessOrder,
-            BBlockTransferSrcVectorDim,
-            BBlockTransferSrcScalarPerVector,
-            BBlockTransferDstScalarPerVector_BK1,
-            false,
-            BBlockLdsExtraN,
-            CShuffleMXdlPerWavePerShuffle,
-            CShuffleNXdlPerWavePerShuffle,
-            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-            CDEShuffleBlockTransferScalarPerVectors,
-            BlkGemmPipeSched,
-            BlkGemmPipelineVer,
-            ComputeTypeA,
-            ComputeTypeB,
-            LDSTypeA,
-            LDSTypeB>>;
+            LDSTypeB>;
     using Argument = typename GridwiseGemm::Argument;
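For context, the hunk above collapses the old `std::conditional_t` choice between `GridwiseMoeGemmGather` and `GridwiseMoeGemmScatter` into the single `GridwiseMoeGemm`, keyed by the `bool IsInputGemm` template parameter. A minimal standalone sketch of that pattern, with made-up names (`MiniGather`, `MiniScatter`, `MiniMoeGemm`) rather than the real CK types:

    // Minimal sketch of the merge pattern; not CK code.
    #include <type_traits>

    template <int MPerBlock> struct MiniGather  {};
    template <int MPerBlock> struct MiniScatter {};

    // Before: two gridwise templates selected at the device layer.
    template <bool IsGatherGemm, int MPerBlock>
    using OldGridwise = std::conditional_t<IsGatherGemm, MiniGather<MPerBlock>, MiniScatter<MPerBlock>>;

    // After: one gridwise template; the gather/scatter split becomes a flag the
    // shared implementation can branch on with `if constexpr`.
    template <int MPerBlock, bool IsInputGemm>
    struct MiniMoeGemm
    {
        static constexpr bool is_input_gemm = IsInputGemm;
    };

    static_assert(std::is_same_v<OldGridwise<true, 128>, MiniGather<128>>);
    static_assert(MiniMoeGemm<128, true>::is_input_gemm);

    int main() { return 0; }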
@@ -305,86 +254,51 @@ struct DeviceMoeGemm
             // {
             //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
             //     {
-            //         if constexpr (IsGatherGemm) {
-            //             const auto kernel = kernel_moe_gemm_gather<
+            //         const auto kernel = kernel_moe_gemm<
             //             GridwiseGemm,
             //             true,
             //             InMemoryDataOperationEnum::AtomicAdd,
             //             minimum_occupancy,
+            //             IsInputGemm,
             //             TailNumber::Odd>;
             //         RunKernel(kernel);
-            //         else {
-            //             const auto kernel = kernel_moe_gemm_scatter<
-            //             GridwiseGemm,
-            //             true,
-            //             InMemoryDataOperationEnum::AtomicAdd,
-            //             minimum_occupancy,
-            //             TailNumber::Odd>;
-            //             RunKernel(kernel);
-            //         }
             //     }
             //     else
             //     {
-            //         if constexpr (IsGatherGemm) {
-            //             const auto kernel = kernel_moe_gemm_gather<
-            //             GridwiseGemm,
-            //             true,
-            //             InMemoryDataOperationEnum::AtomicAdd,
-            //             minimum_occupancy,
-            //             TailNumber::Even>;
-            //             RunKernel(kernel);
-            //         else {
-            //             const auto kernel = kernel_moe_gemm_scatter<
+            //         const auto kernel = kernel_moe_gemm<
             //             GridwiseGemm,
             //             true,
             //             InMemoryDataOperationEnum::AtomicAdd,
             //             minimum_occupancy,
+            //             IsInputGemm,
             //             TailNumber::Even>;
             //         RunKernel(kernel);
-            //         }
             //     }
             // }
             // else
             {
+                constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
                 // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                 // {
-                //     if constexpr (IsGatherGemm) {
-                //         const auto kernel = kernel_moe_gemm_gather<
-                //         GridwiseGemm,
-                //         true,
-                //         InMemoryDataOperationEnum::Set,
-                //         minimum_occupancy,
-                //         TailNumber::Odd>;
-                //         RunKernel(kernel);
-                //     } else {
-                //         const auto kernel = kernel_moe_gemm_scatter<
+                //     const auto kernel = kernel_moe_gemm<
                 //         GridwiseGemm,
                 //         true,
-                //         InMemoryDataOperationEnum::AtomicAdd,
+                //         MemoryDataOp,
                 //         minimum_occupancy,
+                //         IsInputGemm,
                 //         TailNumber::Odd>;
                 //     RunKernel(kernel);
-                //     }
                 // }
                 // else
                 {
-                    if constexpr (IsGatherGemm) {
-                        const auto kernel = kernel_moe_gemm_gather<
-                        GridwiseGemm,
-                        true,
-                        InMemoryDataOperationEnum::Set,
-                        minimum_occupancy,
+                    const auto kernel = kernel_moe_gemm<
                        GridwiseGemm,
                        true,
                        MemoryDataOp,
                        minimum_occupancy,
                        IsInputGemm,
                        TailNumber::Even>;
                    RunKernel(kernel);
-                    } else {
-                        const auto kernel = kernel_moe_gemm_scatter<
-                        GridwiseGemm,
-                        true,
-                        InMemoryDataOperationEnum::AtomicAdd,
-                        minimum_occupancy,
-                        TailNumber::Even>;
-                        RunKernel(kernel);
-                    }
                }
            }
        }
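The live change in this block is the new `MemoryDataOp` constant and the extra `IsInputGemm` argument passed to `kernel_moe_gemm`. Below is a standalone sketch of the selection with a simplified stand-in for the enum (not the CK declaration); the reading that the output GEMM needs atomic adds because several expert rows fold back into one token row comes from the scatter code later in this commit, not from anything stated by the author.

    // Standalone sketch, not CK code: simplified stand-in enum.
    enum class InMemoryDataOperationEnum { Set, AtomicAdd };

    template <bool IsInputGemm>
    constexpr InMemoryDataOperationEnum memory_data_op()
    {
        // Mirrors `IsInputGemm ? Set : AtomicAdd` from the diff: the input (gather)
        // GEMM writes each destination row once, while the output (scatter) GEMM
        // accumulates into shared token rows.
        return IsInputGemm ? InMemoryDataOperationEnum::Set
                           : InMemoryDataOperationEnum::AtomicAdd;
    }

    static_assert(memory_data_op<true>() == InMemoryDataOperationEnum::Set);
    static_assert(memory_data_op<false>() == InMemoryDataOperationEnum::AtomicAdd);

    int main() { return 0; }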
@@ -423,7 +337,7 @@ struct DeviceMoeGemm
                 //     kernel_moe_gemm_gather_2lds<
                 //         GridwiseGemm,
                 //         true,
-                //         IsGatherGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
+                //         IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
                 //         minimum_occupancy,
                 //         TailNumber::Odd>;
                 //     RunKernel(kernel);
@@ -434,7 +348,7 @@ struct DeviceMoeGemm
                 //     kernel_moe_gemm_gather_2lds<
                 //         GridwiseGemm,
                 //         true,
-                //         IsGatherGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
+                //         IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
                 //         minimum_occupancy,
                 //         TailNumber::Even>;
                 //     RunKernel(kernel);
......
@@ -30,20 +30,21 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           index_t MinimumOccupancy = 1,
+          bool IsInputGemm = false,
           TailNumber TailNum = TailNumber::Even>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
-    kernel_moe_gemm_gather(typename GridwiseGemm::Argument karg)
+    kernel_moe_gemm(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
-    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
         karg.p_sorted_token_ids,
         karg.p_sorted_expert_ids,
         karg.p_max_token_id,
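The wrapper kernel simply threads the new flag through to the gridwise body; the `::template` disambiguator is needed because the member function template is called on a dependent type. A standalone sketch of that forwarding, with host-side stand-ins and no HIP launch (names here are illustrative, not CK types):

    // Standalone sketch, not CK code: forwarding a bool non-type template
    // parameter from a wrapper into a dependent type's member function template.
    #include <cstdio>

    struct MiniGridwise
    {
        template <bool HasMainKBlockLoop, bool IsInputGemm>
        static void Run(int tokens, int topk)
        {
            // Resolved at compile time, so both phases share one body.
            std::printf("%s gemm, rows = %d\n",
                        IsInputGemm ? "input" : "output",
                        IsInputGemm ? tokens : tokens * topk);
        }
    };

    template <typename GridwiseGemm, bool HasMainKBlockLoop, bool IsInputGemm = false>
    void mini_kernel(int tokens, int topk)
    {
        // `::template` is required because GridwiseGemm is a dependent type here.
        GridwiseGemm::template Run<HasMainKBlockLoop, IsInputGemm>(tokens, topk);
    }

    int main()
    {
        mini_kernel<MiniGridwise, true, true>(8, 2);  // gather/input phase
        mini_kernel<MiniGridwise, true, false>(8, 2); // scatter/output phase
        return 0;
    }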
@@ -145,7 +146,7 @@ template <typename ALayout,
           typename ComputeTypeB = ComputeTypeA,
           typename LDSTypeA = ADataType,
           typename LDSTypeB = BDataType>
-struct GridwiseMoeGemmGather
+struct GridwiseMoeGemm
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -1121,6 +1122,7 @@ struct GridwiseMoeGemmGather
     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              bool IsInputGemm = true,
               TailNumber TailNum = TailNumber::Odd>
     __device__ static void Run(
         const index_t* p_sorted_token_ids,
@@ -1138,11 +1140,11 @@ struct GridwiseMoeGemmGather
     {
         ignore = b_element_op;
         const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-            problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
+            IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
         const auto b_grid_desc_bpreshuffled =
             MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
-            problem.NumTokens * problem.TopK, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+            IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
         // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
         //        problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
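These two edits swap which grid descriptor sees the token-expanded row count: with `IsInputGemm` true, A is sized by `NumTokens` and C by `NumTokens * TopK`; with it false, the roles flip. A standalone sketch of just that extent selection, using plain functions rather than the CK descriptor builders:

    // Standalone sketch, not CK code: only the row-count selection is modeled.
    struct MiniProblem { int NumTokens; int TopK; };

    constexpr int a_rows(MiniProblem p, bool is_input_gemm)
    {
        return is_input_gemm ? p.NumTokens : p.NumTokens * p.TopK;
    }

    constexpr int c_rows(MiniProblem p, bool is_input_gemm)
    {
        return is_input_gemm ? p.NumTokens * p.TopK : p.NumTokens;
    }

    // 8 tokens routed to TopK = 2 experts each.
    static_assert(a_rows({8, 2}, true) == 8 && c_rows({8, 2}, true) == 16);
    static_assert(a_rows({8, 2}, false) == 16 && c_rows({8, 2}, false) == 8);

    int main() { return 0; }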
@@ -1177,8 +1179,12 @@ struct GridwiseMoeGemmGather
             return;
         StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
         static_for<0, AMRepeats, 1>{}([&](auto m0) {
-            const index_t token_offset = (token_pos + m0 < max_token_id) ?
-                (p_sorted_token_ids[token_pos + m0] & 0xffffff) : problem.NumTokens;
+            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
+            index_t token_offset      = fused_token & 0xffffff;
+            if constexpr (!IsInputGemm)
+            {
+                token_offset = token_offset * problem.TopK + (fused_token >> 24);
+            }
             gather_offsets(m0) = token_offset * problem.K;
             // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
         });
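The gather loop now decodes the packed entry from `p_sorted_token_ids` in place: judging by this hunk and the scatter hunk below, the low 24 bits hold the token index and the high 8 bits the top-k slot, and only the output GEMM (`!IsInputGemm`) redirects the read into the token-expanded intermediate. A standalone sketch of the decode, with hypothetical helpers rather than CK code:

    // Standalone sketch, not CK code: decoding the packed sorted-token entry.
    #include <cstdint>

    constexpr std::int32_t token_of(std::int32_t fused) { return fused & 0xffffff; }
    constexpr std::int32_t slot_of(std::int32_t fused) { return fused >> 24; }

    // Row index into A, mirroring the gather_offsets computation; the caller
    // multiplies by K (elements per row) to get the element offset.
    constexpr std::int32_t gather_row(std::int32_t fused, std::int32_t topk, bool is_input_gemm)
    {
        std::int32_t row = token_of(fused);
        if(!is_input_gemm)
        {
            // Output GEMM reads the token-expanded intermediate: one row per (token, slot).
            row = row * topk + slot_of(fused);
        }
        return row;
    }

    static_assert(gather_row((1 << 24) | 5, 2, true) == 5);   // token 5, slot 1, input GEMM
    static_assert(gather_row((1 << 24) | 5, 2, false) == 11); // same entry, output GEMM

    int main() { return 0; }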
@@ -1464,16 +1470,26 @@ struct GridwiseMoeGemmGather
         StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
         StaticallyIndexedArray<float, EMRepeats> scatter_weights;   //= for topk
         // too hack here, 2 specific for topk weights, fixme
-        const float* p_sorted_weights = p_ds_grid[I0];
+        const float* p_sorted_weights_0 = p_ds_grid[I0];
         // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
             const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
-            scatter_offsets(m0) = ((fused_token & 0xffffff) * problem.TopK + (fused_token >> 24)) * problem.N;
-            scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
+            index_t token_offset = fused_token & 0xffffff;
+            float weight         = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
+            if constexpr (IsInputGemm)
+            {
+                token_offset = token_offset * problem.TopK + (fused_token >> 24);
+            } else {
+                const float* p_sorted_weights_2 = p_ds_grid[I2];
+                weight = weight * p_sorted_weights_2[c_token_pos + m0];
+            }
+            scatter_offsets(m0) = token_offset * problem.N;
+            scatter_weights(m0) = weight;
             // if(threadIdx.x % 16 == 0)
             //     printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
         });
+        constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
         auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
             ThisThreadBlock,
             decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
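The scatter side is the mirror image of the gather decode: the input GEMM scatters into the token-expanded layout and takes its per-row weight from `Ds[0]`, while the output GEMM scatters back to the original token row and multiplies in a second per-row factor from `Ds[2]`, which is also what the new `scatter_weight_idx = IsInputGemm ? 1 : 3` selector feeds into the transfer below. A standalone sketch of the offset/weight computation, with the two D-tensor reads reduced to plain scalars (not CK code):

    // Standalone sketch, not CK code: per-row scatter offset/weight selection.
    #include <cstdint>

    constexpr std::int32_t token_of(std::int32_t fused) { return fused & 0xffffff; }
    constexpr std::int32_t slot_of(std::int32_t fused) { return fused >> 24; }

    struct ScatterRow { std::int32_t row; float weight; };

    // w0 stands in for the Ds[0] value at this sorted position, w2 for the Ds[2] value.
    constexpr ScatterRow scatter_row(std::int32_t fused, std::int32_t topk, bool is_input_gemm,
                                     float w0, float w2)
    {
        std::int32_t row = token_of(fused);
        float weight     = w0;
        if(is_input_gemm)
        {
            row = row * topk + slot_of(fused); // write into the token-expanded layout
        }
        else
        {
            weight = weight * w2; // fold in the second per-row factor
        }
        return {row, weight}; // caller multiplies row by N for the element offset
    }

    static_assert(scatter_row((1 << 24) | 5, 2, true, 0.25f, 2.0f).row == 11);
    static_assert(scatter_row((1 << 24) | 5, 2, false, 0.25f, 2.0f).row == 5);

    int main() { return 0; }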
@@ -1502,7 +1518,7 @@ struct GridwiseMoeGemmGather
             Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
             1,               // ScatterDim
             true,            // OutputScatter: false, only use scatter weights
-            1                // ScatterWeightIdx: ascale
+            scatter_weight_idx // ScatterWeightIdx: ascale
             >
             {c_ds_desc_refs,
              idx_c_ds_block_begin,
......