gaoqiong / composable_kernel_ROCM · Commit c8517e46
"profiler/vscode:/vscode.git/clone" did not exist on "afc7d431d976113660f5e8e3b4d8453336fa136e"
Unverified commit, authored Feb 11, 2025 by jakpiase; committed via GitHub on Feb 11, 2025.

Merge branch 'develop' into jakpiase/ck_tile_examples_package

Parents: 17aa1102, c0adab48
Showing 11 changed files with 936 additions and 71 deletions (+936 −71)
example/ck_tile/13_moe_sorting/moe_sorting.cpp                   +57  −6
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp               +82  −0
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp               +2   −1
example/ck_tile/13_moe_sorting/script/smoke_test.sh              +8   −0
example/ck_tile/15_fused_moe/README.md                           +1   −1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp  +74  −0
include/ck_tile/host/reference/reference_moe_sorting.hpp         +24  −2
include/ck_tile/ops/fused_moe.hpp                                +1   −1
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp    +1   −1
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp      +634 −59
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp     +52  −0
example/ck_tile/13_moe_sorting/moe_sorting.cpp
...
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
        .insert("k", "4", "topk")
        .insert("unit", "32", "unit_size")
        .insert("moe_buf_size", "0", "moe_buf_size")
        .insert("local_eid",
                "-1",
                "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
                "please make sure eid is in ascending order!")
        .insert("seed", "-1", "seed to be used, -1 means random every time")
        .insert("kname", "0", "when set to 1 it will print kernel name")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
...
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
    int kname  = args.get_int("kname");
    int warmup = args.get_int("warmup");
    int repeat = args.get_int("repeat");
    int max_output_ids = ck_tile::integer_least_multiple(
        topk * tokens + num_experts * unit_size - topk, unit_size);
...
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
        return false;
    }
    bool local_expert_masking = args.get_str("local_eid") != "-1";
    auto local_expert_masking_host = [&]() {
        if(local_expert_masking)
        {
            auto local_eid = args.get_int_vec("local_eid");
            // std::vector<int> v_ {num_experts, 0};
            ck_tile::HostTensor<IndexType> v_{{num_experts}};
            v_.SetZero();
            for(auto eid : local_eid)
            {
                if(eid >= num_experts)
                {
                    throw std::runtime_error("local_eid larger than number of expert, please check");
                }
                v_.mData[eid] = 1;
            }
            return v_;
        }
        else
            // return std::vector<int>{};
            return ck_tile::HostTensor<IndexType>{{1}};
    }();

    // tokens already considered batch size
    ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
...
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
                                          sorted_expert_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem local_expert_masking_dev(
        local_expert_masking_host.get_element_space_size_in_bytes());

    topk_ids_dev.ToDevice(topk_ids_host.data());
    weights_dev.ToDevice(weights_host.data());
...
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
    {
        moe_buf_dev.ToDevice(moe_buf_host.data());
    }
    if(local_expert_masking)
        local_expert_masking_dev.ToDevice(local_expert_masking_host.data());

-   moe_sorting_trait trait{index_prec, weight_prec};
+   moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};

    moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
                          weights_dev.GetDeviceBuffer(),
                          local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
                          sorted_ids_dev.GetDeviceBuffer(),
                          sorted_weights_dev.GetDeviceBuffer(),
                          sorted_expert_ids_dev.GetDeviceBuffer(),
...
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
                            warmup,
                            repeat};
    auto ms = moe_sorting(trait, karg, sc);

-   printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f, ",
+   printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
           index_prec.c_str(),
           weight_prec.c_str(),
           tokens,
           num_experts,
-          topk,
-          ms);
+          topk);
    if(local_expert_masking)
    {
        printf("local_eid:%s, ", args.get_str("local_eid").c_str());
    }
    if(ms < 0)
        printf("not supported\n");
    else
        printf("ms:%f, ", ms);
    fflush(stdout);
    if(ms < 0)
    {
...
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
    int32_t ref_total_tokens_post_pad = 0;
    ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
                                                          weights_host,
                                                          local_expert_masking_host,
                                                          sorted_ids_ref,
                                                          sorted_weights_ref,
                                                          sorted_expert_ids_ref,
                                                          ref_total_tokens_post_pad,
                                                          num_experts,
-                                                         unit_size);
+                                                         unit_size,
+                                                         local_expert_masking);

    rtn &= ck_tile::check_err(
        sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
    rtn &= ck_tile::check_err(sorted_weights_host,
...
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
                moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
        }
        rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
        printf("total_tokens_post_pad:%d(%d), ", ref_total_tokens_post_pad, sorted_id_cnt_host.mData[0]);
    }

-   printf("valid:%s\n", rtn ? "y" : "n");
+   printf("valid:%s", rtn ? "y" : "n");
    fflush(stdout);
    if(!rtn)
        printf(", (%d)", seed);
    printf("\n");
    fflush(stdout);
    return rtn;
}
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
...
...
@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp"

#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif

#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
    constexpr ck_tile::index_t unroll_num  = unroll_num_; \
    constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
...
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
...
...
@@ -38,11 +105,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
#if !MOE_SORTING_USE_EX_KERNEL
        if(a.num_experts > 127)
        {
            printf("lds size exceed, only support experts <127\n");
...
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
            MOE_SORTING_DISPATCH(4);
        }
    }
#else
    using index_t        = ck_tile::index_t;
    using ms_weight_type = float;
    auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
    auto sub_token_ = r_ - 2;
    r_ = (r_ - 2) / 8;
    bool is_sub_token_onshot     = a.tokens <= sub_token_;
    bool is_local_expert_masking = t.local_expert_masking;
    (void)c_;
    MOE_SORTING_DISPATCH_EMASK_(r_);
    // MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
    }
    return -1;
}
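The macro chain above (MOE_SORTING_DISPATCH_EMASK_ → _SUBTO_ → _SUB_TOKEN_) flattens to a simple divisibility rule for the compile-time SubTokenTile. A minimal sketch under that reading (pick_sub_token_tile is a hypothetical helper for illustration, not part of this API):

// Sketch only: choose the largest SubTokenTile in {8, 4, 2, 1} that divides
// the sub-token row count, mirroring MOE_SORTING_DISPATCH_SUB_TOKEN_ above.
inline int pick_sub_token_tile(int rows)
{
    if(rows % 8 == 0) return 8;
    if(rows % 4 == 0) return 4;
    if(rows % 2 == 0) return 2;
    return 1;
}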
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
...
...
@@ -10,7 +10,8 @@
struct moe_sorting_trait
{
    std::string index_type;
-   std::string weight_type; // currently always float
+   std::string weight_type;   // currently always float
+   bool local_expert_masking; // if mask experts as local expert
};

struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
...
...
example/ck_tile/13_moe_sorting/script/smoke_test.sh
...
...
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
example/ck_tile/15_fused_moe/README.md
...
...
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
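For a quick sense of the updated bound, a hedged worked example with hypothetical sizes (numbers chosen for illustration, not taken from the diff):

// Hypothetical numbers, for illustration only:
constexpr int topk = 2, input_tokens = 8, num_experts = 4, M_a = 32;
constexpr int old_bound = topk * input_tokens + num_experts * (M_a - 1);  // 140
constexpr int new_bound = topk * input_tokens + num_experts * M_a - topk; // 142
// The two formulas differ by exactly num_experts - topk.
static_assert(new_bound - old_bound == num_experts - topk);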
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...
...
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"

#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif

#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
    constexpr ck_tile::index_t unroll_num  = unroll_num_; \
    constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
@@ -17,6 +23,24 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
using ms_problem = \
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
...
...
@@ -38,11 +62,13 @@
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
#endif
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
#if !MOE_SORTING_USE_EX_KERNEL
        if(a.num_experts > 127)
        {
            printf("lds size exceed, only support experts <127\n");
...
...
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
            MOE_SORTING_DISPATCH(4);
        }
    }
#else
    using index_t        = ck_tile::index_t;
    using ms_weight_type = float;
    auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
    auto sub_token_ = r_ - 2;
    r_ = (r_ - 2) / 8;
    bool is_sub_token_onshot = a.tokens <= sub_token_;
    (void)c_;
    if(is_sub_token_onshot)
    {
        if(r_ % 8 == 0)
        {
            MOE_SORTING_DISPATCH_(8, true);
        }
        else if(r_ % 4 == 0)
        {
            MOE_SORTING_DISPATCH_(4, true);
        }
        else if(r_ % 2 == 0)
        {
            MOE_SORTING_DISPATCH_(2, true);
        }
        else
        {
            MOE_SORTING_DISPATCH_(1, true);
        }
    }
    else
    {
        if(r_ % 8 == 0)
        {
            MOE_SORTING_DISPATCH_(8, false);
        }
        else if(r_ % 4 == 0)
        {
            MOE_SORTING_DISPATCH_(4, false);
        }
        else if(r_ % 2 == 0)
        {
            MOE_SORTING_DISPATCH_(2, false);
        }
        else
        {
            MOE_SORTING_DISPATCH_(1, false);
        }
    }
    // MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
    }
    return -1;
}
include/ck_tile/host/reference/reference_moe_sorting.hpp
...
...
@@ -14,12 +14,15 @@ namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        const HostTensor<IndexType>& local_expert_mask,
                                        HostTensor<IndexType>& p_sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
-                                       const index_t unit_size)
+                                       const index_t unit_size,
+                                       bool local_expert_masking,
+                                       bool skip_experts_with_zero_token = true)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];
...
...
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    // count number of unit-size slices in this expert
    std::vector<IndexType> expert_slices(experts, 1);
    // count the tokens used in this expert
    std::vector<IndexType> expert_slice_idxs(experts, 0);
    // TODO: above 2 buffers seem duplicated
    for(index_t t = 0; t < num_token; t++)
    {
...
...
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
    IndexType* out_tokens    = p_sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    int curr_expert_id       = 0;
    for(index_t e = 0; e < experts; e++)
    {
        if(local_expert_masking)
        {
            if(local_expert_mask(e) == 0)
                continue;
        }
        if(skip_experts_with_zero_token)
        {
            if(expert_slice_idxs[e] == 0)
            {
                curr_expert_id++;
                continue;
            }
        }
        memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,
...
...
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
-           out_expert_id[s] = e;
+           out_expert_id[s] = curr_expert_id;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
        curr_expert_id++;
    }
    unit_cnt *= unit_size;
    return;
...
...
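To make the curr_expert_id remapping above concrete, the mask example documented in moe_sorting_kernel.hpp maps as follows:

// local_expert_mask = [1, 0, 1, 1, 0, 1] -> experts 1 and 4 are masked out
// global expert ids kept   : 0, 2, 3, 5
// local expert ids written : 0, 1, 2, 3   (what curr_expert_id assigns)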
include/ck_tile/ops/fused_moe.hpp
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
...
...
@@ -14,7 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...
...
@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
...
...
@@ -15,6 +15,10 @@ namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
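A worked expansion of the packing macro above, with illustrative argument values:

// MOE_SORTING_MOCK_ID(5, 2)
//   == ((5 & 0x00ffffff) | ((2 & 0xff) << 24))
//   == 0x02000005
// token_id occupies the low 24 bits, topk_id the high 8 bits.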
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
...
...
@@ -28,7 +32,7 @@ namespace ck_tile {
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
@@ -55,6 +59,34 @@ namespace ck_tile {
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
// if enabled, an expert with no tokens will be skipped, instead of being padded to at least 1 unit_size (M_a)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
// num_tokens_post_padded_ptr : [24]
//
// * local_expert_mask : indicates the local expert mask used on the current GPU (used for the EP case)
//                       and modifies the output expert-ID, because we will only have enabled experts on a specific GPU.
//                       we call the expert input to this kernel the "global expert id", and the output the "local expert id"
//
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
//
// (pack below tensor, skip element marked with `-`)
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note: originally it was expert-id = 0, 2, 3, 5, but we produce the "local expert id")
// num_tokens_post_padded_ptr : [20]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2) need sorted_weight_ptr
...
...
@@ -67,10 +99,80 @@ namespace ck_tile {
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
{
    /*            num_experts + 1
     * +--------------------------------------+
     * |                                      |
     * |                                      |
     * |                                      |  * -> sub-tokens
     * |                                      |
     * |                                      |
     * +--------------------------------------+
     * |                                      |  2 -> cumsum buffer
     * +--------------------------------------+
     */
    int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
    int smem_rows = [&]() {
        index_t target_occupancy_     = 2;
        constexpr index_t total_      = 65536 / sizeof(int);
        constexpr index_t sub_unroll  = 8;
        constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
        // at least 2 lines, one for sub_token unroll, one for cumsum
        // should be enough
        if((total_ / target_occupancy_) < ((cumsum_bufs + sub_unroll) * smem_cols))
        {
            if((total_ / 1) < ((cumsum_bufs + sub_unroll) * smem_cols))
                throw std::runtime_error("too many num_experts, can't allocate smem");
            target_occupancy_ = 1;
        }
        int r = total_ / target_occupancy_ / smem_cols;
        // round to sub_unroll multiple
        int r_for_sub_token = r - cumsum_bufs;
        r_for_sub_token     = min(r_for_sub_token, num_tokens_);
        r_for_sub_token     = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
        r_for_sub_token     = max(r_for_sub_token, 1);
        if(r_for_sub_token > 1)
        {
            int r_unroll_ = r_for_sub_token / sub_unroll;
            // round to 1x/2x/4x/8x number of sub_unroll
            int clz_  = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
            int mask_ = (1 << (31 - clz_)) - 1;
            mask_     = mask_ > 0b111 ? 0b111 : mask_; // clamp to 8x at most
            mask_     = ~mask_;
            // printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
            r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
        }
        // final check
        if((r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_) >= total_)
        {
            throw std::runtime_error("can't run this kernel, request LDS over size");
        }
        return r_for_sub_token + cumsum_bufs;
    }();
    // printf("r:%d, c:%d\n", smem_rows, smem_cols);
    return ck_tile::make_tuple(smem_rows, smem_cols);
}
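A small host-side usage sketch of the helper above (the token/expert counts are hypothetical); it matches how GetSmemSize() below converts the shape into a byte count:

// Sketch: query the LDS shape the ex-kernel will request.
auto [rows, cols] = ck_tile::moe_sorting_get_smem_row_col(/*num_tokens_=*/4096,
                                                          /*num_experts_=*/64);
// cols == num_experts + 1 == 65 (one padding column);
// rows == sub-token rows + 2 cumsum rows.
size_t lds_bytes = rows * cols * sizeof(int);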
struct MoeSortingHostArgs
{
    const void* p_topk_ids; // [token, topk]
    const void* p_weights;  // [token, topk]
    const void* p_local_expert_mask;
    void* p_sorted_token_ids;
    void* p_sorted_weights;
    void* p_sorted_expert_ids;
...
...
@@ -101,6 +203,7 @@ struct MoeSortingKernel
{
    const void* p_topk_ids;
    const void* p_weights;
    const void* p_local_expert_mask;
    void* p_sorted_token_ids;
    void* p_sorted_weights;
    void* p_sorted_expert_ids;
...
...
@@ -111,8 +214,11 @@ struct MoeSortingKernel
    index_t moe_buf_bytes;
    index_t tokens_per_thread;
    index_t smem_rows;
    mdiv unit_size_mdiv;
    mdiv topk_mdiv;
    mdiv expert_mdiv;
    // mdiv sub_tokens_mdiv;
};

CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
...
...
@@ -123,15 +229,25 @@ struct MoeSortingKernel
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
    (void)h;
    return dim3(256);
#else
    return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
#endif
}

// in byte
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
#if MOE_SORTING_USE_EX_KERNEL
    auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
    return smem_rows * smem_cols * sizeof(int);
#else
    const auto blocks = BlockSize(h);
    // usually num_experts is power of 2, we pad 1 dword here for the row-size
    return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
#endif
}

CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
...
...
@@ -139,6 +255,7 @@ struct MoeSortingKernel
    Kargs k;
    k.p_topk_ids          = h.p_topk_ids;
    k.p_weights           = h.p_weights;
    k.p_local_expert_mask = h.p_local_expert_mask;
    k.p_sorted_token_ids  = h.p_sorted_token_ids;
    k.p_sorted_weights    = h.p_sorted_weights;
    k.p_sorted_expert_ids = h.p_sorted_expert_ids;
...
...
@@ -152,10 +269,18 @@ struct MoeSortingKernel
    k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
    k.unit_size_mdiv    = mdiv{static_cast<uint32_t>(h.unit_size)};
    k.topk_mdiv         = mdiv{static_cast<uint32_t>(h.topk)};
    k.smem_rows         = [&]() {
        auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
        (void)c_;
        return r_;
    }();
    k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
    // k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
    return k;
}
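The mdiv members set up above are ck_tile's precomputed magic-number division helpers. A usage sketch consistent with how this kernel calls them elsewhere in the diff (the concrete numbers are illustrative; the exact signatures are assumed from those call sites):

ck_tile::mdiv topk_mdiv{4u};        // divide by 4 without hardware integer division
uint32_t q, r;
topk_mdiv.divmod(11u, q, r);        // q == 2, r == 3 (same pattern as divmod(i, token_id, topk_id))
uint32_t blocks = topk_mdiv.div(9); // blocks == 2
// topk_mdiv.divisor == 4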
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
// NOTE: wave_size needs to be at least 16!! dpp16 is one row
template <typename data_t, int wave_size>
__device__ inline void wave_cumsum(data_t& thread_data) const
{
...
...
@@ -196,6 +321,40 @@ struct MoeSortingKernel
                                             bank_mask,
                                             bound_ctrl))); // row_shr:4
    }
    if constexpr(wave_size == 8)
    {
        // wave-size=8 needs one extra shift
        thread_data = reduce_op(
            thread_data,
            __builtin_bit_cast(data_t,
                               __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
                                                        0x118,
                                                        row_mask,
                                                        bank_mask,
                                                        bound_ctrl))); // row_shr:8
#if 0
        constexpr int bank_mask_0_7 = 0b1100;
        auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
        thread_data = reduce_op_r(thread_data,
                                  __builtin_bit_cast(data_t,
                                                     __builtin_amdgcn_update_dpp(0, /* old value */
                                                                                 __builtin_bit_cast(int, thread_data),
                                                                                 0x157,
                                                                                 row_mask,
                                                                                 bank_mask_0_7,
                                                                                 bound_ctrl)) // row_newbcast:7
        );
#else
        data_t xxx = __builtin_bit_cast(data_t,
                                        __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
                                                                 0x157,
                                                                 row_mask,
                                                                 bank_mask,
                                                                 bound_ctrl)); // row_newbcast:7
        data_t yyy  = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
        thread_data = thread_data - yyy;
#endif
    }
    if constexpr(wave_size > 8)
    {
        thread_data =
...
...
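Semantics of wave_cumsum, traced on a small illustrative input (values not from the source):

// Inclusive per-lane prefix sum across a wave:
// input  thread_data (lanes 0..3): [3, 1, 4, 1]
// output thread_data (lanes 0..3): [3, 4, 8, 9]   // lane i holds the sum of lanes 0..i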
@@ -224,6 +383,36 @@ struct MoeSortingKernel
    }
}

// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = warpSize>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
    // constexpr int wave_size = 64;
    // constexpr int reduce_stage = 6; // 1<<6=64
    // clang-format off
    constexpr int reduce_stage = []() {
        if constexpr(wave_size_ == 2) return 1;
        else if constexpr(wave_size_ == 4) return 2;
        else if constexpr(wave_size_ == 8) return 3;
        else if constexpr(wave_size_ == 16) return 4;
        else if constexpr(wave_size_ == 32) return 5;
        else if constexpr(wave_size_ == 64) return 6;
        else return 0;
    }();
    // clang-format on
    T v_local = local;
#pragma unroll reduce_stage
    for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
    {
        int src_lane = __lane_id() ^ (1 << i_stage);
        int32_t v_remote_tmp =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
        T v_remote = bit_cast<T>(v_remote_tmp);
        v_local    = reduce_f(v_local, v_remote);
    }
    return v_local;
}
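wave_reduce implements a butterfly (XOR-pairing) all-reduce: at stage s every lane exchanges with lane ^ (1 << s) via ds_bpermute, so after log2(wave_size_) stages each lane holds the full reduction. An illustrative 4-lane trace with f_sum:

// start    (lanes 0..3): [a, b, c, d]
// stage 0, xor 1       : [a+b, a+b, c+d, c+d]
// stage 1, xor 2       : [a+b+c+d, a+b+c+d, a+b+c+d, a+b+c+d]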
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
    return row * total_col + col;
...
...
@@ -257,37 +446,37 @@ struct MoeSortingKernel
    index_t* shared_mem = reinterpret_cast<index_t*>(smem);

    index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
    index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)

    for(int i = 0; i < num_experts; ++i)
    {
        tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
    }

#pragma unroll Problem_::InternalLoadUnroll
    for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
    {
        ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
    }
    __syncthreads();

#if 1
    if(tid < num_experts)
    {
        tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
        index_t local_c[8];
        index_t prev_c = 0;
        // TODO: manually unroll. pragma unroll does not work well when we have dependency
        for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
        {
            local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
            local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
            local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
            local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
            local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
            local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
            local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
            local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
            local_c[0] += prev_c;
            local_c[1] += local_c[0];
...
...
@@ -299,51 +488,57 @@ struct MoeSortingKernel
            local_c[7] += local_c[6];
            prev_c = local_c[7];
            tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
            tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
            tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
            tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
            tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
            tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
            tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
            tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
        }
    }
#else
    // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
    // heuristic
    {
        if(tid < num_experts)
            tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
        for(int i = 0; i < num_experts; i += 8)
        {
            index_t local_c[8];
#pragma unroll
            for(int j = 0; j < 8; j++)
            {
                local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
            }
#pragma unroll
            for(int j = 0; j < 8; j++)
            {
                wave_cumsum<int, 64>(local_c[j]);
            }
#pragma unroll
            for(int j = 0; j < 8; j++)
            {
                tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
            }
        }
    }
#endif
    __syncthreads();

    if constexpr(Problem::ExpertTile == 0)
    {
        if(tid == 0)
        {
            cumsum[0] = 0;
            for(int i = 1; i <= num_experts; ++i)
            {
                auto current_units = [&]() {
                    index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
                                 unit_size_mdiv.divisor - 1;
                    index_t y_ = unit_size_mdiv.div(x_);
                    return max(y_, 1) * unit_size_mdiv.divisor;
                }();
...
...
@@ -351,20 +546,24 @@ struct MoeSortingKernel
            }
            *p_total_tokens_post_pad = cumsum[num_experts];
        }
    }
    else
    {
        // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
        // expert) for simplicity, not check experts here.
        int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
        int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
        int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
        int local_cumsum = padded_tokens_per_expert;
        wave_cumsum<int, 64>(local_cumsum);
        if(tid == (num_experts - 1))
        {
            cumsum[0] = 0;
            *p_total_tokens_post_pad = local_cumsum;
        }
        if(tid < num_experts)
        {
            cumsum[tid + 1] = local_cumsum;
        }
    }
...
...
@@ -373,7 +572,7 @@ struct MoeSortingKernel
    if(tid < num_experts)
    {
        int e_start = cumsum[tid];
        int e_end   = cumsum[tid + 1];
        for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
        {
            p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
...
...
@@ -383,8 +582,8 @@ struct MoeSortingKernel
#pragma unroll Problem_::InternalLoadUnroll
    for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
    {
        index_t expert_id     = topk_id[i];
        index_t local_cnt     = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
        index_t rank_post_pad = local_cnt + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
        uint32_t curr_token_id, curr_topk_id;
...
...
@@ -393,16 +592,17 @@ struct MoeSortingKernel
#else
        p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
        p_sorted_weights[rank_post_pad] = weights[i];
        tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
    }

    if constexpr(Problem::ExpertTile == 0)
    {
        const index_t prefill_token = topk_mdiv.div(numel);
        if(tid < num_experts)
        {
            index_t expert_offset =
                cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
            index_t expert_end = cumsum[tid + 1];
            while(expert_offset < expert_end)
            {
...
...
@@ -417,16 +617,19 @@ struct MoeSortingKernel
            }
        }
    }
    else
    {
        const index_t prefill_token = topk_mdiv.div(numel);
        // TODO: only support expert-tile like 8, 16, 32
        static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
        {
            index_t eid = tid / experts_per_wave;
            index_t expert_offset = cumsum[eid] +
                                    tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
                                    tid % experts_per_wave;
            index_t expert_end = cumsum[eid + 1];
            if(eid < num_experts)
            {
                while(expert_offset < expert_end)
                {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
...
...
@@ -436,10 +639,363 @@ struct MoeSortingKernel
                    p_sorted_token_ids[expert_offset] = prefill_token;
#endif
                    p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
                    expert_offset += experts_per_wave;
                }
            }
        }
    }
}

// only support index_t, and single pixel access
struct simple_smem_indexer
{
    index_t* smem;
    index_t row_stride;

    // this is 2D
    CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
        : smem(smem_), row_stride(row_stride_)
    {
    }
    CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
    {
        return smem[i_row * row_stride + i_col];
    }
    CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
    {
        return smem[i_row * row_stride + i_col];
    }

    // this is 1D or linear
    CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
    CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
    CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
};

CK_TILE_DEVICE void
moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
                               const WeightType* __restrict__ weights,
                               const IndexType* __restrict__ local_expert_mask,
                               index_t* p_sorted_token_ids,
                               WeightType* p_sorted_weights,
                               index_t* p_sorted_expert_ids,
                               index_t* p_total_tokens_post_pad,
                               const index_t num_experts,
                               const index_t tokens,
                               const mdiv unit_size_mdiv,
                               const mdiv topk_mdiv,
                               const mdiv expert_mdiv,
                               const index_t smem_rows,
                               void* smem) const
{
    const index_t tid = static_cast<index_t>(threadIdx.x);
    const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
    const index_t lid = __lane_id();
    constexpr index_t block_size = 256;           // blockDim.x;
    const index_t sub_tokens = smem_rows - 2;     // sub_tokens_mdiv.divisor;
    const index_t topk       = topk_mdiv.divisor;
    auto f_sum = [](auto x_, auto y_) { return x_ + y_; };

    const index_t smem_cols = num_experts + 1;
    simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
    simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
    simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols, smem_cols};

    // #pragma unroll 8
    for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
    {
        uint32_t curr_token_id, curr_expert_id;
        expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
        smem_tokens(curr_token_id, curr_expert_id) = 0;
    }
    __syncthreads();

    for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
    {
        // NOTE: below for loop can't have barrier inside!!
        for(int i = tid; i < (sub_tokens * topk); i += block_size)
        {
            uint32_t curr_token_id, curr_topk_id;
            topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
            int i_t = i_token + curr_token_id;
            if(i_t < tokens)
            {
                int eid = topk_id[i_t * topk + curr_topk_id];
                if constexpr(Problem::SubTokenOneShot)
                    smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
                else
                    smem_tokens(curr_token_id, eid)++;
            }
            __builtin_amdgcn_s_waitcnt(0xc07f);
        }
        __syncthreads(); // make sure different i_token iterations don't overlap across waves
    }

    // counting
    if(tid == 0)
    {
        smem_cumsum(0) = 0;
        // smem_cumdup(0) = 0;
    }
    {
        constexpr int lane_group_sz = 8;
        int lane_group_id = tid / lane_group_sz;
        int lane_group_os = tid % lane_group_sz;
        constexpr int lane_group_nm = block_size / lane_group_sz;
        for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
        {
            index_t local_c[Problem::SubTokenTile];
            index_t cnt = 0;
            for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
            {
#pragma unroll Problem::SubTokenTile
                for(int j = 0; j < Problem::SubTokenTile; j++)
                {
                    local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
                    if constexpr(Problem::SubTokenOneShot)
                    {
                        local_c[j] = local_c[j] != 0 ? 1 : 0;
                    }
                }
#pragma unroll Problem::SubTokenTile
                for(int j = 0; j < Problem::SubTokenTile; j++)
                {
                    cnt += wave_reduce(local_c[j], f_sum, number<8>{});
                }
            }
            if(lane_group_os == 0)
                smem_cumsum(i_e + 1) = cnt;
        }
    }
    if constexpr(Problem::LocalExpertMasking)
    {
        smem_cumdup(0) = 0;
        for(int i_e = tid; i_e < num_experts; i_e += block_size)
        {
            // reuse this buffer
            smem_cumdup(i_e + 1) = local_expert_mask[i_e];
        }
    }
    __syncthreads();
    {
        if(wid == 0)
        {
            // NOTE: under this block can never use __syncthreads!
            int i_e_          = 0;
            int local_cumsum_ = 0;
            for(; i_e_ < num_experts; i_e_ += warpSize)
            {
                int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
                int local_cnt   = smem_cumsum(i_e_ + lid + 1);
                int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
                int pre_cumsum_masking = [&]() {
                    if constexpr(Problem::LocalExpertMasking)
                        return smem_cumdup(lid == 0 ? i_e_ : 0);
                    else
                        return 0; // not used
                }();
                int local_masking = [&]() {
                    if constexpr(Problem::LocalExpertMasking)
                        return smem_cumdup(i_e_ + lid + 1);
                    else
                        return 0; // not used
                }();
                int padded_tokens_per_expert = [&]() {
                    int x_ = [&]() {
                        if constexpr(Problem::SkipExpertsWithZeroTokens)
                        {
                            // if local_cnt is zero, blocks_pers_expert will be zero
                            // this is what we want to achieve
                            return blocks_pers_expert * unit_size_mdiv.divisor;
                        }
                        else
                        {
                            return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
                        }
                    }();
                    if constexpr(Problem::LocalExpertMasking)
                    {
                        return local_masking ? x_ : 0;
                    }
                    else
                        return x_;
                }();
                local_cumsum_ = padded_tokens_per_expert;
                local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
                                              // cumsum is padded, in case local cumsum is zero but
                                              // pre_cumsum has a value, which would result in a
                                              // zero local cumsum (but we want at least padded)
                wave_cumsum<int, warpSize>(local_cumsum_);
                if((i_e_ + lid) < num_experts)
                    smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
                if constexpr(Problem::LocalExpertMasking)
                {
                    local_masking += pre_cumsum_masking;
                    wave_cumsum<int, warpSize>(local_masking);
                    if((i_e_ + lid) < num_experts)
                        smem_cumdup(i_e_ + lid + 1) = local_masking;
                }
                // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
                // for above write; however __syncthreads would cause a barrier with waves other
                // than 0 (which is not what we want)
                __builtin_amdgcn_s_waitcnt(0xc07f);
            }
            if((lid + i_e_ - warpSize) == (num_experts - 1))
            {
                *p_total_tokens_post_pad = local_cumsum_;
            }
        }
        __syncthreads();
    }

    for(int i_e = tid; i_e < num_experts; i_e += block_size)
    {
        int e_start   = smem_cumsum(i_e);
        int e_end     = smem_cumsum(i_e + 1);
        int expert_id = [&]() {
            if constexpr(Problem::LocalExpertMasking)
            {
                // local expert id from cumsum
                return smem_cumdup(i_e);
            }
            else
                return i_e;
        }();
        smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
        if constexpr(Problem::SkipExpertsWithZeroTokens)
        {
            if(e_start == e_end) // skip zero token expert
                continue;
        }
        if constexpr(Problem::LocalExpertMasking)
        {
            if(local_expert_mask[i_e] == 0)
                continue;
        }
        for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
        {
            p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
        }
    }
    smem_cumdup(num_experts) = smem_cumsum(num_experts);

    // fill the p_sorted_token_ids/p_sorted_weights
    for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
    {
        if constexpr(!Problem::SubTokenOneShot)
        {
            // clear every time
            for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
            {
                uint32_t curr_token_id, curr_expert_id;
                expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
                smem_tokens(curr_token_id, curr_expert_id) = 0;
            }
            __syncthreads();
            // load again
            for(int i = tid; i < (sub_tokens * topk); i += block_size)
            {
                uint32_t curr_token_id_, curr_topk_id_;
                topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
                int curr_token_id = static_cast<int>(curr_token_id_);
                int curr_topk_id  = static_cast<int>(curr_topk_id_);
                int i_t = i_token + curr_token_id;
                if(i_t < tokens)
                {
                    int eid = topk_id[i_t * topk + curr_topk_id];
                    smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
                }
            }
            __syncthreads();
        }
        {
            constexpr int lane_group_sz = 8;
            int lane_group_id = tid / lane_group_sz;
            int lane_group_os = tid % lane_group_sz;
            constexpr int lane_group_nm = block_size / lane_group_sz;
            for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
            {
                if constexpr(Problem::LocalExpertMasking)
                {
                    if(local_expert_mask[eid] == 0)
                        continue;
                }
                int position = smem_cumsum(eid);
                for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
                    i_sub_token += lane_group_sz)
                {
                    auto x = smem_tokens(i_sub_token, eid);
                    int local_cnt_cache = x != 0 ? 1 : 0;
                    int local_cnt       = local_cnt_cache;
                    wave_cumsum<int, lane_group_sz>(local_cnt);
                    if(x != 0)
                    {
                        // now x is topk value
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                        p_sorted_token_ids[position + local_cnt - 1] =
                            MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
                        p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
#endif
                        p_sorted_weights[position + local_cnt - 1] =
                            weights[(i_token + i_sub_token) * topk + x - 1];
                    }
                    int remote_cnt = __builtin_amdgcn_ds_bpermute(
                        (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
                    position += remote_cnt;
                }
                smem_cumsum(eid) = position;
            }
        }
        __syncthreads();
    }

    // add the skip number
    for(int eid = tid; eid < num_experts; eid += block_size)
    {
        int e_start = smem_cumsum(eid);
        int e_end   = smem_cumdup(eid + 1);
        if constexpr(Problem::SkipExpertsWithZeroTokens)
        {
            if(e_start == e_end) // skip zero token expert
                continue;
        }
        while(e_start < e_end)
        {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
#else
            p_sorted_token_ids[e_start] = tokens;
#endif
            p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
            e_start++;
        }
    }
}
...
...
@@ -456,6 +1012,24 @@ struct MoeSortingKernel
    }
    const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
    extern __shared__ char smem[];
#if MOE_SORTING_USE_EX_KERNEL
    (void)numel;
    return moe_align_block_size_kernel_ex(static_cast<const IndexType*>(kargs.p_topk_ids),
                                          static_cast<const WeightType*>(kargs.p_weights),
                                          static_cast<const IndexType*>(kargs.p_local_expert_mask),
                                          static_cast<IndexType*>(kargs.p_sorted_token_ids),
                                          static_cast<WeightType*>(kargs.p_sorted_weights),
                                          static_cast<IndexType*>(kargs.p_sorted_expert_ids),
                                          static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
                                          kargs.num_experts,
                                          kargs.tokens,
                                          kargs.unit_size_mdiv,
                                          kargs.topk_mdiv,
                                          kargs.expert_mdiv,
                                          kargs.smem_rows,
                                          smem);
#else
    return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
                                       static_cast<const WeightType*>(kargs.p_weights),
                                       static_cast<IndexType*>(kargs.p_sorted_token_ids),
...
...
@@ -468,6 +1042,7 @@ struct MoeSortingKernel
                                       kargs.unit_size_mdiv,
                                       kargs.topk_mdiv,
                                       smem);
#endif
}
};
...
...
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp → include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
...
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
        InternalLoadUnroll_; // TODO: need better design (like tile size)
    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};

template <typename IndexType_,
          typename WeightType_,
          index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
          bool SubTokenOneShot_, // if we only loop over once or not
          bool LocalExpertMasking_, // used in EP case
          bool SkipExpertsWithZeroTokens_ = true,
          index_t ExpertTile_             = 0>
struct MoeSortingProblemEx
{
    // TODO: this kernel only support warp per row
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t WarpSize      = get_warp_size();
    static constexpr index_t WarpsPerBlock = 1;
    static constexpr index_t SubTokenTile  = SubTokenTile_;
    static constexpr bool SubTokenOneShot  = SubTokenOneShot_;
    static constexpr bool LocalExpertMasking = LocalExpertMasking_;
    static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
    static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
} // namespace ck_tile
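For reference, an instantiation sketch of MoeSortingProblemEx, mirroring what MOE_SORTING_DISPATCH_ in moe_sorting_api.cpp expands to (the concrete tile of 8 and the flag values are example choices, not fixed ones):

using ms_problem = ck_tile::MoeSortingProblemEx<ck_tile::index_t, // IndexType_
                                                float,            // WeightType_
                                                8,                // SubTokenTile_ (1/2/4/8)
                                                true,             // SubTokenOneShot_
                                                false>;           // LocalExpertMasking_
using ms_kernel  = ck_tile::MoeSortingKernel<ms_problem>;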