Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8f4dc357
Commit
8f4dc357
authored
Nov 02, 2024
by
dummycoderfe
Browse files
Add a loop unroll for MoE LDS ops
parent
68952cba
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
17 deletions
+57
-17
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+50
-10
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
...ude/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
+3
-4
include/ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp
.../ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp
+4
-3
No files found.
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
8f4dc357
...
...
@@ -3,6 +3,17 @@
#include "moe_sorting_api.hpp"
// MOE_SORTING_DISPATCH(unroll_num_)
//
// Instantiates the MoE sorting kernel for one compile-time unroll factor and
// launches it, then returns the launch result from the enclosing function.
//
// unroll_num_ : must be a constant expression; it initializes a constexpr
//               variable and is forwarded as the third template argument of
//               ck_tile::MoeSortingProblem.
//
// Requirements at the expansion site:
//   - the names `index_t` and `ms_weight_type` must be visible (they are used
//     unqualified as template arguments);
//   - `a` must be the kernel arguments passed to MakeKargs/GridSize/BlockSize;
//   - `s` must be the stream config passed to ck_tile::launch_kernel.
//
// Caveats:
//   - This expands to several statements, not one expression, and it declares
//     the locals unroll_num, ms_problem, kernel, kargs, grids, blocks and
//     ave_time. Expand it inside its own `{ ... }` scope so that repeated
//     expansions do not redeclare those names.
//   - The expansion ends with `return ave_time;`, so nothing placed after the
//     macro in the same scope is executed.
//   - NOTE(review): the literal `0` passed to make_kernel is presumably the
//     dynamic shared-memory size; confirm against ck_tile::make_kernel.
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \
return ave_time;
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
...
...
@@ -12,16 +23,45 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
printf
(
"lds size exceed, only support experts <127
\n
"
);
return
-
1
;
}
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
using
ms_problem
=
ck_tile
::
MoeSortingProblem
<
index_t
,
ms_weight_type
>
;
using
kernel
=
ck_tile
::
MoeSortingKernel
<
ms_problem
>
;
auto
kargs
=
kernel
::
MakeKargs
(
a
);
const
dim3
grids
=
kernel
::
GridSize
(
a
);
const
dim3
blocks
=
kernel
::
BlockSize
(
a
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
(
kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
index_t
smem_io_unroll_num
=
ck_tile
::
integer_divide_ceil
(
a
.
tokens
*
a
.
topk
,
64
);
switch
(
smem_io_unroll_num
)
{
case
(
1
):
{
MOE_SORTING_DISPATCH
(
1
);
}
case
(
2
):
{
MOE_SORTING_DISPATCH
(
2
);
}
case
(
3
):
{
MOE_SORTING_DISPATCH
(
3
);
}
case
(
5
):
{
MOE_SORTING_DISPATCH
(
5
);
}
case
(
6
):
{
MOE_SORTING_DISPATCH
(
6
);
}
case
(
7
):
{
MOE_SORTING_DISPATCH
(
7
);
}
case
(
8
):
{
MOE_SORTING_DISPATCH
(
8
);
}
case
(
9
):
{
MOE_SORTING_DISPATCH
(
9
);
}
case
(
10
):
{
MOE_SORTING_DISPATCH
(
10
);
}
case
(
11
):
{
MOE_SORTING_DISPATCH
(
11
);
}
default:
{
MOE_SORTING_DISPATCH
(
4
);
}
}
}
return
-
1
;
}
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
View file @
8f4dc357
...
...
@@ -63,7 +63,7 @@ struct MoeSortingKernel
CK_TILE_HOST
static
constexpr
auto
BlockSize
(
const
Hargs
&
h
)
{
// TODO: need to pad to a multiple of warp size
return
dim3
(
ck_tile
::
max
(
h
.
num_experts
,
ck_tile
::
get_warp_size
()));
return
dim3
(
ck_tile
::
integer_least_multiple
(
h
.
num_experts
,
ck_tile
::
get_warp_size
()));
}
// in byte
...
...
@@ -119,12 +119,11 @@ struct MoeSortingKernel
index_t
*
tokens_cnts
=
shared_mem
;
// 2d: (blockDim.x + 1, num_experts)
index_t
*
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
num_experts
;
// 1: (num_experts + 1)
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
calc_index
(
num_experts
,
tid
+
1
,
i
)]
=
0
;
}
#pragma unroll Problem_::InternalLoadUnroll
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
+
1
,
topk_id
[
i
])];
...
...
@@ -157,7 +156,6 @@ struct MoeSortingKernel
}
*
total_tokens_post_pad
=
unit_size_mdiv
.
div
(
cumsum
[
num_experts
]);
}
__syncthreads
();
if
(
tid
<
num_experts
)
{
...
...
@@ -167,6 +165,7 @@ struct MoeSortingKernel
}
}
#pragma unroll Problem_::InternalLoadUnroll
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
index_t
expert_id
=
topk_id
[
i
];
...
...
include/ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp
View file @
8f4dc357
...
...
@@ -9,14 +9,15 @@
namespace
ck_tile
{
template
<
typename
IndexType_
,
typename
WeightType_
>
template
<
typename
IndexType_
,
typename
WeightType_
,
index_t
InternalLoadUnroll_
>
struct
MoeSortingProblem
{
// TODO: this kernel only supports warp per row
using
WeightType
=
remove_cvref_t
<
WeightType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static
constexpr
index_t
WarpsPerBlock
=
1
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static
constexpr
index_t
WarpsPerBlock
=
1
;
static
constexpr
index_t
InternalLoadUnroll
=
InternalLoadUnroll_
;
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment