Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a61084f4
Commit
a61084f4
authored
Feb 18, 2025
by
coderfeli
Committed by
mtgu0705
Feb 18, 2025
Browse files
opt gemm2 to 2x2 wave
parent
854cd8b4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
6 deletions
+8
-6
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+8
-6
No files found.
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
a61084f4
...
@@ -127,12 +127,14 @@ using CDEElementOp = MulABScaleExpertWeight;
...
@@ -127,12 +127,14 @@ using CDEElementOp = MulABScaleExpertWeight;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
BLOCKSIZE
=
256
;
static
constexpr
ck
::
index_t
BLOCKSIZE
=
256
;
static
constexpr
ck
::
index_t
MXDLPerWave
=
2
;
static
constexpr
ck
::
index_t
NXDLPerWave
=
2
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
//
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
static
constexpr
ck
::
index_t
CShuffleMXDLPerWave
=
MPerBlock
/
32
;
//
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static
constexpr
ck
::
index_t
CShuffleNLane
=
NPerBlock
/
2
;
static
constexpr
ck
::
index_t
CShuffleNLane
=
3
2
;
static
constexpr
ck
::
index_t
CShuffleMLane
=
BLOCKSIZE
/
CShuffleNLane
;
static
constexpr
ck
::
index_t
CShuffleMLane
=
BLOCKSIZE
/
CShuffleNLane
;
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
...
@@ -159,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -159,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// mn_perxdl
// mn_perxdl
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
// mn_xdlperwave
// mn_xdlperwave
MXDLPerWave
,
1
,
MXDLPerWave
,
NXDLPerWave
,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...
@@ -168,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
...
@@ -168,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffle
MXDLPerWave
,
1
,
S
<
1
,
CShuffleMLane
,
1
,
CShuffleNLane
>
,
S
<
EVec
,
D0Vec
,
D1Vec
,
D2Vec
>
,
MXDLPerWave
,
1
,
S
<
1
,
CShuffleMLane
,
1
,
CShuffleNLane
>
,
S
<
EVec
,
D0Vec
,
D1Vec
,
D2Vec
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
false
,
A0DataType
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
true
,
false
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment