Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3b27b10e
Commit
3b27b10e
authored
Feb 13, 2025
by
valarLip
Browse files
porting fmoe_sorting from moe_sorting
parent
16fa63ea
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
61 deletions
+74
-61
example/ck_tile/15_fused_moe/fused_moe.hpp
example/ck_tile/15_fused_moe/fused_moe.hpp
+11
-8
example/ck_tile/15_fused_moe/fused_moesorting.hpp
example/ck_tile/15_fused_moe/fused_moesorting.hpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+59
-51
No files found.
example/ck_tile/15_fused_moe/fused_moe.hpp
View file @
3b27b10e
...
...
@@ -8,14 +8,15 @@
struct
fused_moe_args
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
const
void
*
local_expert_mask_ptr
;
// [e], local_expert_mask_ptr for EP
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
...
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
int
activation
;
// 0:gelu, 1:silu
int
gate_only
;
// 0:g1u0, 1:g1u1
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool
local_expert_masking
;
// if mask experts as local expert
};
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/15_fused_moe/fused_moesorting.hpp
View file @
3b27b10e
...
...
@@ -10,7 +10,8 @@
struct
fused_moesorting_trait
{
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
View file @
3b27b10e
...
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return
1
;
}();
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
};
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
,
t
.
local_expert_masking
};
auto
a0
=
fused_moesorting_args
{
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
local_expert_mask_ptr
,
// const void* p_local_expert_mask;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
...
...
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
View file @
3b27b10e
...
...
@@ -24,20 +24,63 @@
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
...
...
@@ -116,47 +159,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
if
(
is_sub_token_onshot
)
{
if
(
r_
%
8
==
0
)
{
MOE_SORTING_DISPATCH_
(
8
,
true
);
}
else
if
(
r_
%
4
==
0
)
{
MOE_SORTING_DISPATCH_
(
4
,
true
);
}
else
if
(
r_
%
2
==
0
)
{
MOE_SORTING_DISPATCH_
(
2
,
true
);
}
else
{
MOE_SORTING_DISPATCH_
(
1
,
true
);
}
}
else
{
if
(
r_
%
8
==
0
)
{
MOE_SORTING_DISPATCH_
(
8
,
false
);
}
else
if
(
r_
%
4
==
0
)
{
MOE_SORTING_DISPATCH_
(
4
,
false
);
}
else
if
(
r_
%
2
==
0
)
{
MOE_SORTING_DISPATCH_
(
2
,
false
);
}
else
{
MOE_SORTING_DISPATCH_
(
1
,
false
);
}
}
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
return
-
1
;
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment