Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e78fbf87
"profiler/vscode:/vscode.git/clone" did not exist on "b653c5eb2e440a181dde86fc29696851f329ab96"
Commit
e78fbf87
authored
Feb 18, 2025
by
coderfeli
Browse files
merge 2 moegemm pipe together
parent
1687fc98
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
1739 deletions
+51
-1739
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
...e/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
+24
-110
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+27
-11
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
...k/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
+0
-1618
No files found.
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
View file @
e78fbf87
...
...
@@ -12,8 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm
_gather
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp"
//
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
...
...
@@ -66,7 +66,7 @@ template <typename ALayout,
typename
CDEShuffleBlockTransferScalarPerVectors
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
bool
Is
Gather
Gemm
=
true
,
bool
Is
Input
Gemm
=
true
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ComputeTypeA
,
...
...
@@ -85,8 +85,8 @@ struct DeviceMoeGemm
CElementwiseOperation
>
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
GridwiseGemm
=
std
::
conditional_t
<
IsGatherGemm
,
GridwiseMoeGemm
Gather
<
using
GridwiseGemm
=
GridwiseMoeGemm
<
ALayout
,
BLayout
,
DsLayout
,
...
...
@@ -136,58 +136,7 @@ struct DeviceMoeGemm
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
LDSTypeB
>
,
GridwiseMoeGemmScatter
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
ADataType
,
BDataType
,
GemmAccDataType
,
CShuffleDataType
,
DsDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVectors
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
LDSTypeB
>>
;
LDSTypeB
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
...
...
@@ -305,86 +254,51 @@ struct DeviceMoeGemm
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// if constexpr (IsGatherGemm) {
// const auto kernel = kernel_moe_gemm_gather<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
// else {
// const auto kernel = kernel_moe_gemm_scatter<
// const auto kernel = kernel_moe_gemm<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Odd>;
// RunKernel(kernel);
// }
// }
// else
// {
// if constexpr (IsGatherGemm) {
// const auto kernel = kernel_moe_gemm_gather<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// RunKernel(kernel);
// else {
// const auto kernel = kernel_moe_gemm_scatter<
// const auto kernel = kernel_moe_gemm<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Even>;
// RunKernel(kernel);
// }
// }
// }
// else
{
constexpr
auto
MemoryDataOp
=
IsInputGemm
?
InMemoryDataOperationEnum
::
Set
:
InMemoryDataOperationEnum
::
AtomicAdd
;
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// if constexpr (IsGatherGemm) {
// const auto kernel = kernel_moe_gemm_gather<
// const auto kernel = kernel_moe_gemm<
// GridwiseGemm,
// true,
//
In
MemoryDataOp
erationEnum::Set
,
// MemoryDataOp,
// minimum_occupancy,
// IsInputGemm,
// TailNumber::Odd>;
// RunKernel(kernel);
// } else {
// const auto kernel = kernel_moe_gemm_scatter<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
// }
// }
// else
{
if
constexpr
(
IsGatherGemm
)
{
const
auto
kernel
=
kernel_moe_gemm_gather
<
const
auto
kernel
=
kernel_moe_gemm
<
GridwiseGemm
,
true
,
In
MemoryDataOp
erationEnum
::
Set
,
MemoryDataOp
,
minimum_occupancy
,
IsInputGemm
,
TailNumber
::
Even
>
;
RunKernel
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_moe_gemm_scatter
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
RunKernel
(
kernel
);
}
}
}
}
...
...
@@ -423,7 +337,7 @@ struct DeviceMoeGemm
// kernel_moe_gemm_gather_2lds<
// GridwiseGemm,
// true,
// Is
Gather
Gemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// Is
Input
Gemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
...
...
@@ -434,7 +348,7 @@ struct DeviceMoeGemm
// kernel_moe_gemm_gather_2lds<
// GridwiseGemm,
// true,
// Is
Gather
Gemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// Is
Input
Gemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// RunKernel(kernel);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm
_gather
.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
View file @
e78fbf87
...
...
@@ -30,20 +30,21 @@ template <typename GridwiseGemm,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
bool
IsInputGemm
=
false
,
TailNumber
TailNum
=
TailNumber
::
Even
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_gemm
_gather
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_moe_gemm
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
blockIdx
.
z
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
IsInputGemm
,
TailNum
>(
karg
.
p_sorted_token_ids
,
karg
.
p_sorted_expert_ids
,
karg
.
p_max_token_id
,
...
...
@@ -145,7 +146,7 @@ template <typename ALayout,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ADataType
,
typename
LDSTypeB
=
BDataType
>
struct
GridwiseMoeGemm
Gather
struct
GridwiseMoeGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -1121,6 +1122,7 @@ struct GridwiseMoeGemmGather
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
bool
IsInputGemm
=
true
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run
(
const
index_t
*
p_sorted_token_ids
,
...
...
@@ -1138,11 +1140,11 @@ struct GridwiseMoeGemmGather
{
ignore
=
b_element_op
;
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
IsInputGemm
?
problem
.
NumTokens
:
problem
.
NumTokens
*
problem
.
TopK
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
problem
.
NumTokens
*
problem
.
TopK
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
IsInputGemm
?
problem
.
NumTokens
*
problem
.
TopK
:
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -1177,8 +1179,12 @@ struct GridwiseMoeGemmGather
return
;
StaticallyIndexedArray
<
index_t
,
AMRepeats
>
gather_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
AMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
const
index_t
token_offset
=
(
token_pos
+
m0
<
max_token_id
)
?
(
p_sorted_token_ids
[
token_pos
+
m0
]
&
0xffffff
)
:
problem
.
NumTokens
;
const
index_t
fused_token
=
p_sorted_token_ids
[
token_pos
+
m0
];
index_t
token_offset
=
fused_token
&
0xffffff
;
if
constexpr
(
!
IsInputGemm
)
{
token_offset
=
token_offset
*
problem
.
TopK
+
(
fused_token
>>
24
);
}
gather_offsets
(
m0
)
=
token_offset
*
problem
.
K
;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
...
...
@@ -1464,16 +1470,26 @@ struct GridwiseMoeGemmGather
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray
<
float
,
EMRepeats
>
scatter_weights
;
//= for topk
// too hack here, 2 specific for topk weights, fixme
const
float
*
p_sorted_weights
=
p_ds_grid
[
I0
];
const
float
*
p_sorted_weights
_0
=
p_ds_grid
[
I0
];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
const
index_t
fused_token
=
p_sorted_token_ids
[
c_token_pos
+
m0
];
scatter_offsets
(
m0
)
=
((
fused_token
&
0xffffff
)
*
problem
.
TopK
+
(
fused_token
>>
24
))
*
problem
.
N
;
scatter_weights
(
m0
)
=
p_sorted_weights
[(
c_token_pos
+
m0
)
*
problem
.
StrideDs
[
0
]];
index_t
token_offset
=
fused_token
&
0xffffff
;
float
weight
=
p_sorted_weights_0
[(
c_token_pos
+
m0
)
*
problem
.
StrideDs
[
0
]];
if
constexpr
(
IsInputGemm
)
{
token_offset
=
token_offset
*
problem
.
TopK
+
(
fused_token
>>
24
);
}
else
{
const
float
*
p_sorted_weights_2
=
p_ds_grid
[
I2
];
weight
=
weight
*
p_sorted_weights_2
[
c_token_pos
+
m0
];
}
scatter_offsets
(
m0
)
=
token_offset
*
problem
.
N
;
scatter_weights
(
m0
)
=
weight
;
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
constexpr
index_t
scatter_weight_idx
=
IsInputGemm
?
1
:
3
;
//hack fix felix
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3_scatter
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
...
...
@@ -1502,7 +1518,7 @@ struct GridwiseMoeGemmGather
Sequence
<
false
>
,
// ThreadTransferDstResetCoordinateAfterRunFlags
1
,
//ScatterDim
true
,
//OutputScatter: false, only use scatter weights
1
// ScatterWeightIdx: ascale
scatter_weight_idx
// ScatterWeightIdx: ascale
>
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
deleted
100644 → 0
View file @
1687fc98
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment