Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4404984a
Commit
4404984a
authored
Feb 17, 2025
by
coderfeli
Browse files
2x2 ok
parent
f64b1375
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
477 deletions
+12
-477
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+8
-472
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
...ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
+2
-3
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
4404984a
...
@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// mn_perxdl
// mn_perxdl
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
// mn_xdlperwave
// mn_xdlperwave
MXDLPerWave
,
1
,
2
,
2
,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...
@@ -173,7 +173,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -173,7 +173,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
D0Vec
,
D1Vec
>
,
2
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
D0Vec
,
D1Vec
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
A0DataType
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
4404984a
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
View file @
4404984a
...
@@ -175,7 +175,6 @@ struct GridwiseMoeGemmGather
...
@@ -175,7 +175,6 @@ struct GridwiseMoeGemmGather
static
constexpr
index_t
KRepeat
=
KPerBlock
/
KLane
/
KPack
;
static
constexpr
index_t
KRepeat
=
KPerBlock
/
KLane
/
KPack
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static_assert
(
NWave
*
warpSize
==
BlockSize
);
// static constexpr index_t NumTokens = 1;
// static constexpr index_t NumTokens = 1;
static
constexpr
index_t
SortedTileSize
=
MPerBlock
;
static
constexpr
index_t
SortedTileSize
=
MPerBlock
;
...
@@ -1247,13 +1246,13 @@ struct GridwiseMoeGemmGather
...
@@ -1247,13 +1246,13 @@ struct GridwiseMoeGemmGather
decltype
(
b_grid_desc_bpreshuffled
),
decltype
(
b_grid_desc_bpreshuffled
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
Sequence
<
Number
<
NXdlPerWave
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
BK1Value
>
{}
>
,
Sequence
<
Number
<
NXdlPerWave
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
BK1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
1
,
2
,
0
,
3
>
,
3
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_grid_desc_bpreshuffled
,
true
>
(
b_grid_desc_bpreshuffled
,
make_multi_index
(
n_block_data_idx_on_grid
,
make_multi_index
(
n_block_data_idx_on_grid
,
get_warp_local_1d_id
(),
get_warp_local_1d_id
()
%
NWave
,
0
,
0
,
KPack
*
(
get_thread_local_1d_id
()
%
warpSize
)));
KPack
*
(
get_thread_local_1d_id
()
%
warpSize
)));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment