Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
58db931e
Commit
58db931e
authored
Feb 14, 2025
by
coderfeli
Browse files
fix topk id
parent
84b27d75
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
...ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
+4
-2
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
View file @
58db931e
...
@@ -1170,7 +1170,6 @@ struct GridwiseMoeGemmGather
...
@@ -1170,7 +1170,6 @@ struct GridwiseMoeGemmGather
if
(
token_pos
>=
max_token_id
||
token0
>=
problem
.
NumTokens
)
if
(
token_pos
>=
max_token_id
||
token0
>=
problem
.
NumTokens
)
return
;
return
;
const
index_t
topk_id
=
(
p_sorted_token_ids
[
block_m_id
*
MPerBlock
]
&
0xff000000
)
>>
24
;
StaticallyIndexedArray
<
index_t
,
AMRepeats
>
gather_offsets
;
//= p_sorted_token_ids[token_pos];
StaticallyIndexedArray
<
index_t
,
AMRepeats
>
gather_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
AMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
AMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
const
index_t
token_offset
=
(
token_pos
+
m0
<
max_token_id
)
?
const
index_t
token_offset
=
(
token_pos
+
m0
<
max_token_id
)
?
...
@@ -1463,8 +1462,11 @@ struct GridwiseMoeGemmGather
...
@@ -1463,8 +1462,11 @@ struct GridwiseMoeGemmGather
StaticallyIndexedArray
<
float
,
EMRepeats
>
scatter_weights
;
//= for topk
StaticallyIndexedArray
<
float
,
EMRepeats
>
scatter_weights
;
//= for topk
// too hack here, 2 specific for topk weights, fixme
// too hack here, 2 specific for topk weights, fixme
const
float
*
p_sorted_weights
=
p_ds_grid
[
I0
];
const
float
*
p_sorted_weights
=
p_ds_grid
[
I0
];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
((
p_sorted_token_ids
[
c_token_pos
+
m0
]
&
0xffffff
)
*
problem
.
TopK
+
topk_id
)
*
problem
.
N
;
const
index_t
fused_token
=
p_sorted_token_ids
[
c_token_pos
+
m0
];
scatter_offsets
(
m0
)
=
((
fused_token
&
0xffffff
)
*
problem
.
TopK
+
(
fused_token
>>
24
))
*
problem
.
N
;
scatter_weights
(
m0
)
=
p_sorted_weights
[(
c_token_pos
+
m0
)
*
problem
.
StrideDs
[
0
]];
scatter_weights
(
m0
)
=
p_sorted_weights
[(
c_token_pos
+
m0
)
*
problem
.
StrideDs
[
0
]];
// if(threadIdx.x % 16 == 0)
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment