Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e44e7a95
Commit
e44e7a95
authored
Nov 01, 2024
by
Po Yen, Chen
Browse files
Run remod.py under include/ck_tile & example/ck_tile directories
parent
9964919d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
22 deletions
+24
-22
example/ck_tile/12_moe_sorting/moe_sorting_api.cpp
example/ck_tile/12_moe_sorting/moe_sorting_api.cpp
+3
-3
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-1
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+10
-9
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
...ude/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
+10
-9
No files found.
example/ck_tile/12_moe_sorting/moe_sorting_api.cpp
View file @
e44e7a95
include/ck_tile/host.hpp
View file @
e44e7a95
...
...
@@ -23,12 +23,12 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
include/ck_tile/host/reference/reference_moe_sorting.hpp
View file @
e44e7a95
...
...
@@ -22,8 +22,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const
index_t
topk
=
topk_ids
.
mDesc
.
get_lengths
()[
1
];
std
::
vector
<
std
::
vector
<
IndexType
>>
expert_tokens
(
experts
,
std
::
vector
<
IndexType
>
(
unit_size
,
num_token
));
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
std
::
vector
<
IndexType
>
expert_slice_idxs
(
experts
,
0
);
...
...
@@ -60,8 +60,9 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{
memcpy
(
out_tokens
,
expert_tokens
[
e
].
data
(),
sizeof
(
index_t
)
*
expert_slices
[
e
]
*
unit_size
);
out_tokens
+=
expert_slices
[
e
]
*
unit_size
;
memcpy
(
out_weights
,
expert_token_weights
[
e
].
data
(),
sizeof
(
WeightType
)
*
expert_slices
[
e
]
*
unit_size
);
memcpy
(
out_weights
,
expert_token_weights
[
e
].
data
(),
sizeof
(
WeightType
)
*
expert_slices
[
e
]
*
unit_size
);
out_weights
+=
expert_slices
[
e
]
*
unit_size
;
for
(
index_t
s
=
0
;
s
<
expert_slices
[
e
];
s
++
)
...
...
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
View file @
e44e7a95
...
...
@@ -71,7 +71,7 @@ struct MoeSortingKernel
tokens_cnts
[
calc_index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
calc_index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_id
[
i
])];
}
...
...
@@ -95,7 +95,8 @@ struct MoeSortingKernel
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
max
(
integer_divide_ceil
(
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
unit_size
),
max
(
integer_divide_ceil
(
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
unit_size
),
1
)
*
unit_size
;
}
...
...
@@ -137,12 +138,12 @@ struct MoeSortingKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
const
size_t
numel
=
kargs
.
tokens
*
kargs
.
topk
;
return
moe_align_block_size_kernel
(
static_cast
<
const
IndexType
*>
(
kargs
.
p_topk_ids
),
static_cast
<
const
WeightType
*>
(
kargs
.
p_weights
),
static_cast
<
IndexType
*>
(
kargs
.
sorted_token_ids
),
static_cast
<
WeightType
*>
(
kargs
.
sorted_weights
),
static_cast
<
IndexType
*>
(
kargs
.
expert_ids
),
static_cast
<
IndexType
*>
(
kargs
.
total_tokens_post_pad
),
return
moe_align_block_size_kernel
(
static_cast
<
const
IndexType
*>
(
kargs
.
p_topk_ids
),
static_cast
<
const
WeightType
*>
(
kargs
.
p_weights
),
static_cast
<
IndexType
*>
(
kargs
.
sorted_token_ids
),
static_cast
<
WeightType
*>
(
kargs
.
sorted_weights
),
static_cast
<
IndexType
*>
(
kargs
.
expert_ids
),
static_cast
<
IndexType
*>
(
kargs
.
total_tokens_post_pad
),
kargs
.
num_experts
,
kargs
.
unit_size
,
numel
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment