Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e15351ca
Commit
e15351ca
authored
Feb 07, 2025
by
coderfeli
Browse files
tile m = 64 ok
parent
48d87d9c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
14 deletions
+16
-14
example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp
example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp
+5
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+1
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
...ary/reference_tensor_operation/cpu/reference_moe_gemm.hpp
+10
-7
No files found.
example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp
View file @
e15351ca
...
@@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
...
@@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
//threadnum, mblock, nblock, kblock
//threadnum, mblock, nblock, kblock
256
,
32
,
128
,
128
,
256
,
64
,
128
,
128
,
// ak1, bk1
// ak1, bk1
8
,
8
,
8
,
8
,
// mn_perxdl
// mn_perxdl
32
,
32
,
32
,
32
,
// mn_xdlperwave
// mn_xdlperwave
1
,
1
,
2
,
1
,
// a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
// a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
...
@@ -169,8 +169,8 @@ int main(int argc, char* argv[])
...
@@ -169,8 +169,8 @@ int main(int argc, char* argv[])
ck
::
index_t
N
=
6144
;
ck
::
index_t
N
=
6144
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_num
=
1
;
ck
::
index_t
sorted_tile_size
=
32
;
ck
::
index_t
sorted_tile_size
=
64
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
64
;
ck
::
index_t
tokens
=
64
;
...
@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
...
@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
auto
ref_invoker
=
ref_moe_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_moe_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_moe_gemm
.
MakeArgument
(
auto
ref_argument
=
ref_moe_gemm
.
MakeArgument
(
sorted_token_ids
,
expert_ids
,
a0_t_k
,
b0_e_n_k
,
c_m_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
sorted_token_ids
,
expert_ids
,
sorted_tile_size
,
a0_t_k
,
b0_e_n_k
,
c_m_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
SORTED_SIZE
;
++
m
)
for
(
int
m
=
0
;
m
<
SORTED_SIZE
;
++
m
)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
e15351ca
...
@@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static_assert
(
NWave
*
warpSize
==
BlockSize
);
static_assert
(
NWave
*
warpSize
==
BlockSize
);
// static constexpr index_t NumTokens = 1;
// static constexpr index_t NumTokens = 1;
static
constexpr
index_t
Experts
=
8
;
static
constexpr
index_t
SortedTileSize
=
MPerBlock
;
static
constexpr
index_t
SortedTileSize
=
32
;
static
constexpr
auto
MakeDsGridPointer
()
static
constexpr
auto
MakeDsGridPointer
()
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
View file @
e15351ca
...
@@ -30,14 +30,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -30,14 +30,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
{
{
Argument
(
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids
,
Argument
(
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids
,
const
Tensor
<
ck
::
index_t
>&
expert_ids
,
const
Tensor
<
ck
::
index_t
>&
expert_ids
,
const
index_t
sorted_tile_size
,
const
Tensor
<
ADataType
>&
a_t_k
,
const
Tensor
<
ADataType
>&
a_t_k
,
const
Tensor
<
BDataType
>&
b_e_n_k
,
const
Tensor
<
BDataType
>&
b_e_n_k
,
Tensor
<
CDataType
>&
c_m_n
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
:
expert_ids_
{
expert_ids
},
:
sorted_token_ids_
{
sorted_token_ids
},
sorted_token_ids_
{
sorted_token_ids
},
expert_ids_
{
expert_ids
},
sorted_tile_size_
{
sorted_tile_size
},
a_t_k_
{
a_t_k
},
a_t_k_
{
a_t_k
},
b_e_n_k_
{
b_e_n_k
},
b_e_n_k_
{
b_e_n_k
},
c_m_n_
{
c_m_n
},
c_m_n_
{
c_m_n
},
...
@@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
sorted_tile_size
=
32
;
index_t
sorted_tile_size
_
;
};
};
// Invoker
// Invoker
...
@@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
ComputeTypeA
v_a
{
0
};
ComputeTypeA
v_a
{
0
};
ComputeTypeB
v_b
{
0
};
ComputeTypeB
v_b
{
0
};
const
int
t
=
arg
.
sorted_token_ids_
(
m
);
const
int
t
=
arg
.
sorted_token_ids_
(
m
);
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size
);
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size
_
);
const
int
token_cnt
=
arg
.
a_t_k_
.
mDesc
.
GetLengths
()[
0
];
const
int
token_cnt
=
arg
.
a_t_k_
.
mDesc
.
GetLengths
()[
0
];
if
(
t
<
token_cnt
)
{
if
(
t
<
token_cnt
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
...
@@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
static
auto
MakeArgument
(
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids
,
static
auto
MakeArgument
(
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids
,
const
Tensor
<
ck
::
index_t
>&
expert_ids
,
const
Tensor
<
ck
::
index_t
>&
expert_ids
,
const
index_t
sorted_tile_size
,
const
Tensor
<
ADataType
>&
a_t_k
,
const
Tensor
<
ADataType
>&
a_t_k
,
const
Tensor
<
BDataType
>&
b_e_n_k
,
const
Tensor
<
BDataType
>&
b_e_n_k
,
Tensor
<
CDataType
>&
c_m_n
,
Tensor
<
CDataType
>&
c_m_n
,
...
@@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
sorted_token_ids
,
expert_ids
,
a_t_k
,
b_e_n_k
,
c_m_n
,
a_element_op
,
b_element_op
,
c_element_op
};
return
Argument
{
sorted_token_ids
,
expert_ids
,
sorted_tile_size
,
a_t_k
,
b_e_n_k
,
c_m_n
,
a_element_op
,
b_element_op
,
c_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment