Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7822382b
Commit
7822382b
authored
Feb 07, 2025
by
coderfeli
Browse files
tileM 32,64,128 ok
parent
e15351ca
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
6 deletions
+8
-6
example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp
example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp
+8
-6
No files found.
example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp
View file @
7822382b
...
...
@@ -118,6 +118,8 @@ using BElementOp = PassThrough;
using
CDEElementOp
=
MultiplyMultiply
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
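// TODO: fix this constraint (MXDLPerWave is derived as MPerBlock / 32, so MPerBlock presumably must be a multiple of 32 — confirm)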
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
...
@@ -133,13 +135,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
//threadnum, mblock, nblock, kblock
256
,
64
,
128
,
128
,
256
,
MPerBlock
,
128
,
128
,
// ak1, bk1
8
,
8
,
// mn_perxdl
32
,
32
,
// mn_xdlperwave
2
,
1
,
MXDLPerWave
,
1
,
// a, b: load-transfer cluster, cluster order, src order, src per vec, dst per vec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...
...
@@ -169,10 +171,10 @@ int main(int argc, char* argv[])
ck
::
index_t
N
=
6144
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
1
;
ck
::
index_t
sorted_tile_size
=
64
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
64
;
ck
::
index_t
tokens
=
512
;
if
(
argc
==
1
)
{
...
...
@@ -337,7 +339,7 @@ int main(int argc, char* argv[])
if
(
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
SORTED_SIZE
*
N
*
K
*
experts
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
SORTED_SIZE
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
SORTED_SIZE
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
SORTED_SIZE
*
N
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment