gaoqiong / composable_kernel_ROCM / Commits / 3d15f364

Commit 3d15f364 (Unverified)
Authored Dec 23, 2024 by carlushuang, committed by GitHub on Dec 23, 2024

[CK_TILE] optimize moe-sorting kernel (#1771)

* opt moe sorting
* remove commented code
parent 07339c73

Showing 5 changed files with 289 additions and 80 deletions.
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp                +34 -19
example/ck_tile/13_moe_sorting/script/smoke_test.sh               +2  -1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp   +34 -19
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp       +210 -37
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp    +9  -4
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp

@@ -3,9 +3,11 @@
 #include "moe_sorting_api.hpp"

-#define MOE_SORTING_DISPATCH(unroll_num_)                                                  \
+#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                              \
     constexpr ck_tile::index_t unroll_num = unroll_num_;                                   \
-    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>;    \
+    constexpr ck_tile::index_t expert_tile = expert_tile_;                                 \
+    using ms_problem =                                                                     \
+        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;      \
     using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                              \
     auto kargs       = kernel::MakeKargs(a);                                               \
     const dim3 grids = kernel::GridSize(a);                                                \

@@ -15,6 +17,28 @@
         s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));               \
     return ave_time;

+#define MOE_SORTING_DISPATCH(unroll_num_)               \
+    if(a.num_experts <= 8)                              \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)      \
+    }                                                   \
+    else if(a.num_experts <= 16)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16)     \
+    }                                                   \
+    else if(a.num_experts <= 32)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32)     \
+    }                                                   \
+    else if(a.num_experts <= 64)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64)     \
+    }                                                   \
+    else                                                \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)      \
+    }
+
 float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
 {
     if(t.weight_type == "fp32" && t.index_type == "int32")

@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
     case(6): {
         MOE_SORTING_DISPATCH(6);
     }
-    case(7): {
-        MOE_SORTING_DISPATCH(7);
-    }
     case(8): {
         MOE_SORTING_DISPATCH(8);
     }
-    case(9): {
-        MOE_SORTING_DISPATCH(9);
-    }
     case(10): {
         MOE_SORTING_DISPATCH(10);
     }
-    case(11): {
-        MOE_SORTING_DISPATCH(11);
-    }
     default: {
         MOE_SORTING_DISPATCH(4);
     }
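Note: the dispatch is now two-level. MOE_SORTING_DISPATCH picks a compile-time expert-tile bucket from a.num_experts at runtime and forwards to MOE_SORTING_DISPATCH_ETILE, which instantiates MoeSortingProblem with that ExpertTile. A minimal standalone sketch of the bucket selection; the helper name select_expert_tile is hypothetical and not part of this commit:

#include <cstdint>

// Hypothetical helper mirroring the runtime branch added in MOE_SORTING_DISPATCH:
// num_experts is rounded up to the next supported tile (8/16/32/64); anything
// larger falls back to 0, i.e. the untiled ExpertTile == 0 kernel path.
inline std::int32_t select_expert_tile(std::int32_t num_experts)
{
    if(num_experts <= 8)  return 8;
    if(num_experts <= 16) return 16;
    if(num_experts <= 32) return 32;
    if(num_experts <= 64) return 64;
    return 0;
}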
example/ck_tile/13_moe_sorting/script/smoke_test.sh

@@ -17,3 +17,4 @@ $EXE -t=71 -e=11 -k=11
 $EXE -t=1 -e=1 -k=1
 $EXE -t=99 -e=2 -k=1
 $EXE -t=333 -e=99 -k=13
+$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp

@@ -3,9 +3,11 @@
 #include "fused_moesorting.hpp"

-#define MOE_SORTING_DISPATCH(unroll_num_)                                                  \
+#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                              \
     constexpr ck_tile::index_t unroll_num = unroll_num_;                                   \
-    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>;    \
+    constexpr ck_tile::index_t expert_tile = expert_tile_;                                 \
+    using ms_problem =                                                                     \
+        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;      \
     using kernel     = ck_tile::MoeSortingKernel<ms_problem>;                              \
     auto kargs       = kernel::MakeKargs(a);                                               \
     const dim3 grids = kernel::GridSize(a);                                                \

@@ -15,6 +17,28 @@
         s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));               \
     return ave_time;

+#define MOE_SORTING_DISPATCH(unroll_num_)               \
+    if(a.num_experts <= 8)                              \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)      \
+    }                                                   \
+    else if(a.num_experts <= 16)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16)     \
+    }                                                   \
+    else if(a.num_experts <= 32)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32)     \
+    }                                                   \
+    else if(a.num_experts <= 64)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64)     \
+    }                                                   \
+    else                                                \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)      \
+    }
+
 float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
 {
     if(t.weight_type == "fp32" && t.index_type == "int32")

@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
     case(6): {
         MOE_SORTING_DISPATCH(6);
     }
-    case(7): {
-        MOE_SORTING_DISPATCH(7);
-    }
     case(8): {
         MOE_SORTING_DISPATCH(8);
     }
-    case(9): {
-        MOE_SORTING_DISPATCH(9);
-    }
     case(10): {
         MOE_SORTING_DISPATCH(10);
     }
-    case(11): {
-        MOE_SORTING_DISPATCH(11);
-    }
     default: {
         MOE_SORTING_DISPATCH(4);
     }
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp

@@ -130,7 +130,8 @@ struct MoeSortingKernel
     CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
    {
         const auto blocks = BlockSize(h);
-        return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
+        // usually num_experts is power of 2, we pad 1 dword here for the row-size
+        return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
     }

     CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
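Note: the LDS layout is a (blockDim.x + 1) by num_experts counter table followed by a (num_experts + 1) cumsum row; padding each counter row by one dword means column-wise walks no longer stride by a power of two, avoiding LDS bank conflicts. A rough host-side model of the padded footprint under stated assumptions (moe_sorting_smem_bytes is a hypothetical name; index_t is assumed to be a 32-bit integer):

#include <cstddef>
#include <cstdint>

// Hypothetical model of GetSmemSize() after this change: (block_x + 1) counter
// rows padded to (num_experts + 1) dwords each, plus one (num_experts + 1) cumsum row.
constexpr std::size_t moe_sorting_smem_bytes(std::size_t block_x, std::size_t num_experts)
{
    return ((block_x + 1) * (num_experts + 1) + (num_experts + 1)) * sizeof(std::int32_t);
}

// e.g. 256 threads and 32 experts: (257 * 33 + 33) * 4 bytes, about 33 KiB of LDS
static_assert(moe_sorting_smem_bytes(256, 32) == (257 * 33 + 33) * 4, "layout sketch");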
@@ -154,6 +155,75 @@ struct MoeSortingKernel
         return k;
     }

+    // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
+    template <typename data_t, int wave_size>
+    __device__ inline void wave_cumsum(data_t& thread_data) const
+    {
+        // wave_size must be power of 2
+        constexpr int row_mask    = 0xf;
+        constexpr int bank_mask   = 0xf;
+        constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
+        auto reduce_op            = [&](auto x_, auto y_) { return x_ + y_; };
+        if constexpr(wave_size > 1)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x111,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:1
+        }
+        if constexpr(wave_size > 2)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x112,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:2
+        }
+        if constexpr(wave_size > 4)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x114,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:4
+        }
+        if constexpr(wave_size > 8)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x118,
+                                                            row_mask,
+                                                            bank_mask,
+                                                            bound_ctrl))); // row_shr:8
+        }
+        if constexpr(wave_size > 16)
+        {
+            // now row-0, row-0+row-1, row-1+row-2, row-2+row-3
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+        if constexpr(wave_size > 32)
+        {
+            // lane-id 48...63->31
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+    }
+
     CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
     {
         return row * total_col + col;
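Note: wave_cumsum is an in-register inclusive prefix sum over one wavefront, built from row_shr DPP shifts inside each 16-lane row plus ds_bpermute to carry row totals across rows. For readers less familiar with DPP, a shuffle-based sketch that produces the same inclusive-scan result for int-like data; wave_cumsum_ref is illustrative only, not part of this commit, and costs more cross-lane instructions than the DPP version:

#include <hip/hip_runtime.h>

// Hypothetical reference scan: after the loop each lane holds the sum of its
// own value and all lower lanes within the wave.
template <typename data_t, int wave_size>
__device__ inline void wave_cumsum_ref(data_t& thread_data)
{
    const int lane = __lane_id() % wave_size;
    for(int offset = 1; offset < wave_size; offset <<= 1)
    {
        data_t up = __shfl_up(thread_data, offset, wave_size); // value from lane - offset
        if(lane >= offset)
            thread_data += up; // lanes below 'offset' keep their partial sum unchanged
    }
}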
@@ -187,36 +257,92 @@ struct MoeSortingKernel
         index_t* shared_mem = reinterpret_cast<index_t*>(smem);

         index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
-        index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts;       // 1: (num_experts + 1)
+        index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)

         for(int i = 0; i < num_experts; ++i)
         {
-            tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
+            tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
         }

 #pragma unroll Problem_::InternalLoadUnroll
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
-            ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
+            ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
         }

         __syncthreads();

+#if 1
         if(tid < num_experts)
         {
-            tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
-            for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
-            {
-                tokens_cnts[calc_index(num_experts, i, tid)] +=
-                    tokens_cnts[calc_index(num_experts, i - 1, tid)];
-            }
+            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
+            index_t local_c[8];
+            index_t prev_c = 0;
+            // TODO: manually unroll. pragma unroll does not work well when we have dependency
+            for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
+            {
+                local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
+                local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
+                local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
+                local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
+                local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
+                local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
+                local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
+                local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
+                local_c[0] += prev_c;
+                local_c[1] += local_c[0];
+                local_c[2] += local_c[1];
+                local_c[3] += local_c[2];
+                local_c[4] += local_c[3];
+                local_c[5] += local_c[4];
+                local_c[6] += local_c[5];
+                local_c[7] += local_c[6];
+                prev_c = local_c[7];
+                tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
+                tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
+                tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
+                tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
+                tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
+                tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
+                tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
+                tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
+            }
         }
+#else
+        // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
+        if(tid < num_experts)
+        {
+            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
+        }
+        for(int i = 0; i < num_experts; i += 8)
+        {
+            index_t local_c[8];
+#pragma unroll
+            for(int j = 0; j < 8; j++)
+            {
+                local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
+            }
+#pragma unroll
+            for(int j = 0; j < 8; j++)
+            {
+                wave_cumsum<int, 64>(local_c[j]);
+            }
+            // __syncthreads();
+#pragma unroll
+            for(int j = 0; j < 8; j++)
+            {
+                tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
+            }
+        }
+#endif
+        __syncthreads();

+        if constexpr(Problem::ExpertTile == 0)
+        {
         if(tid == 0)
         {
             cumsum[0] = 0;
             for(int i = 1; i <= num_experts; ++i)
             {
                 auto current_units = [&]() {
-                    index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
+                    index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
                                  unit_size_mdiv.divisor - 1;
                     index_t y_ = unit_size_mdiv.div(x_);
                     return max(y_, 1) * unit_size_mdiv.divisor;
@@ -225,10 +351,30 @@ struct MoeSortingKernel
             }
             *p_total_tokens_post_pad = cumsum[num_experts];
         }
+        }
+        else
+        {
+            // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
+            // for simplicity, not check experts here.
+            int local_cnt          = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
+            int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
+            int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
+            int local_cumsum             = padded_tokens_per_expert;
+            wave_cumsum<int, 64>(local_cumsum);
+            if(tid == (num_experts - 1))
+            {
+                cumsum[0]                = 0;
+                *p_total_tokens_post_pad = local_cumsum;
+            }
+            if(tid < num_experts)
+            {
+                cumsum[tid + 1] = local_cumsum;
+            }
+        }

         __syncthreads();

         if(tid < num_experts)
         {
-            for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
+            int e_start = cumsum[tid];
+            int e_end   = cumsum[tid + 1];
+            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
             {
                 p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
             }
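Note: in both paths each expert's token count is rounded up to whole units of the sorting unit size (with a minimum of one unit) before being prefix-summed into cumsum[]; the new ExpertTile path simply computes that prefix sum with wave_cumsum across lanes instead of a single thread looping over experts. A host-side model of the padding and cumsum (padded_cumsum and its parameters are hypothetical names for illustration):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model: cumsum[e] is the padded start offset of expert e in the
// sorted output; cumsum.back() is the total number of slots after padding.
std::vector<std::int32_t> padded_cumsum(const std::vector<std::int32_t>& tokens_per_expert,
                                        std::int32_t unit_size)
{
    std::vector<std::int32_t> cumsum(tokens_per_expert.size() + 1, 0);
    for(std::size_t e = 0; e < tokens_per_expert.size(); ++e)
    {
        // round up to whole units, but never below one unit per expert
        std::int32_t units = std::max((tokens_per_expert[e] + unit_size - 1) / unit_size, 1);
        cumsum[e + 1]      = cumsum[e] + units * unit_size;
    }
    return cumsum;
}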
@@ -238,8 +384,8 @@ struct MoeSortingKernel
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
             index_t expert_id = topk_id[i];
-            index_t rank_post_pad =
-                tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
+            index_t local_cnt     = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
+            index_t rank_post_pad = local_cnt + cumsum[expert_id];
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
             uint32_t curr_token_id, curr_topk_id;
             topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
@@ -248,15 +394,17 @@ struct MoeSortingKernel
             p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
 #endif
             p_sorted_weights[rank_post_pad] = weights[i];
-            ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
+            tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
         }

+        if constexpr(Problem::ExpertTile == 0)
+        {
         const index_t prefill_token = topk_mdiv.div(numel);
         if(tid < num_experts)
         {
             index_t expert_offset =
-                cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
-            while(expert_offset < cumsum[tid + 1])
+                cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
+            index_t expert_end = cumsum[tid + 1];
+            while(expert_offset < expert_end)
             {
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                 p_sorted_token_ids[expert_offset] =
@@ -269,6 +417,31 @@ struct MoeSortingKernel
                 }
             }
         }
+        else
+        {
+            const index_t prefill_token = topk_mdiv.div(numel);
+            // TODO: only support expert-tile like 8, 16, 32
+            static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
+            {
+                index_t eid           = tid / experts_per_wave;
+                index_t expert_offset =
+                    cumsum[eid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
+                    tid % experts_per_wave;
+                index_t expert_end = cumsum[eid + 1];
+                if(eid < num_experts)
+                {
+                    while(expert_offset < expert_end)
+                    {
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                        p_sorted_token_ids[expert_offset] =
+                            MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+#else
+                        p_sorted_token_ids[expert_offset] = prefill_token;
+#endif
+                        p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                        expert_offset += experts_per_wave;
+                    }
+                }
+            }
+        }
     }

     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
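Note: in the ExpertTile != 0 path the padded tail of each expert's segment is prefilled cooperatively: a small group of lanes shares one expert and writes the filler token id and zero weight in a strided loop, instead of one thread walking its expert's whole tail. A simplified standalone sketch of that striding; fill_padding and all parameter names are hypothetical stand-ins for the quantities computed in the kernel above:

#include <cstdint>

// Hypothetical model of the cooperative tail fill: each of lanes_per_expert
// lanes owns one offset in [0, lanes_per_expert) and fills every
// lanes_per_expert-th padded slot of its expert's segment.
void fill_padding(std::int32_t* sorted_token_ids, float* sorted_weights,
                  std::int32_t first_pad, std::int32_t expert_end,
                  std::int32_t lane_in_group, std::int32_t lanes_per_expert,
                  std::int32_t prefill_token)
{
    for(std::int32_t slot = first_pad + lane_in_group; slot < expert_end; slot += lanes_per_expert)
    {
        sorted_token_ids[slot] = prefill_token; // filler entry past the real tokens
        sorted_weights[slot]   = 0.0f;          // filler weight contributes nothing
    }
}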
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp

@@ -9,7 +9,10 @@
 namespace ck_tile {

-template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
+template <typename IndexType_,
+          typename WeightType_,
+          index_t InternalLoadUnroll_,
+          index_t ExpertTile_ = 0>
 struct MoeSortingProblem
 {
     // TODO: this kernel only support warp per row

@@ -18,6 +21,8 @@ struct MoeSortingProblem
     static constexpr index_t WarpSize           = get_warp_size();
     static constexpr index_t WarpsPerBlock      = 1;
     static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
+    // TODO: need better design(like tile size)
+    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
 };
 } // namespace ck_tile
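Note: ExpertTile_ is a defaulted template parameter, so existing three-argument instantiations keep the previous untiled behaviour. A hedged usage sketch; float as the weight type matches the fp32 path in the API files above, and the aliases and include path are illustrative rather than taken from this commit:

#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"

// Previous-style instantiation: ExpertTile_ defaults to 0 (no expert tiling).
using problem_untiled = ck_tile::MoeSortingProblem<ck_tile::index_t, float, 4>;

// New-style instantiation: pick an expert tile explicitly (here 32, as the
// dispatch macros do for num_experts in the range (16, 32]).
using problem_tiled = ck_tile::MoeSortingProblem<ck_tile::index_t, float, 4, 32>;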