Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
167deece
Commit
167deece
authored
Feb 15, 2025
by
mtgu0705
Browse files
test expert = 8 and function pass.
parent
f904a37d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp
example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp
+7
-6
No files found.
example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp
View file @
167deece
...
@@ -152,7 +152,8 @@ using AElementOp = PassThrough;
...
@@ -152,7 +152,8 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
#if 0
#if 1
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
CShuffleMXDLPerWave
=
MPerBlock
/
32
;
static
constexpr
ck
::
index_t
CShuffleMXDLPerWave
=
MPerBlock
/
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
KPerBlock
=
128
/
sizeof
(
A0DataType
);
...
@@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
...
@@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row
,
Col
,
DsLayout
,
ELayout
,
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
6
4
, MPerBlock, 1
6
, KPerBlock,
25
6
,
MPerBlock
,
1
28
,
KPerBlock
,
AK1
,
BK1
,
AK1
,
BK1
,
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
MNPerXDL
,
MXDLPerWave
,
1
,
MXDLPerWave
,
1
,
...
@@ -208,12 +209,12 @@ int main(int argc, char* argv[])
...
@@ -208,12 +209,12 @@ int main(int argc, char* argv[])
// GEMM shape
// GEMM shape
ck
::
index_t
N
=
6144
;
ck
::
index_t
N
=
6144
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
experts
=
1
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
1
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
//
ck::index_t tokens = 128;
ck
::
index_t
tokens
=
128
;
ck
::
index_t
tokens
=
16
;
//
ck::index_t tokens = 16;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment