Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d4154c35
Unverified
Commit
d4154c35
authored
May 14, 2025
by
Jinzhen Lin
Committed by
GitHub
May 13, 2025
Browse files
[Bugfix] fix moe marlin `topk_weight` loading (#18080)
Co-authored-by: mgoin <mgoin64@gmail.com>
parent
6685890d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing 1 changed file with 7 additions and 7 deletions
+7
-7
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+7
-7
No files found.
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
d4154c35
...
@@ -473,15 +473,15 @@ __global__ void Marlin(
...
@@ -473,15 +473,15 @@ __global__ void Marlin(
if
(
mul_topk_weights
)
{
if
(
mul_topk_weights
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
idx
=
tid4
*
4
+
i
;
idx
=
idx
<
block_num_valid_tokens
?
idx
:
0
;
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
__hmul2
(
sh_block_topk_weights
[
idx
]
=
__hmul2
(
global_scale
,
global_scale
,
Dtype
::
num2num2
(
Dtype
::
float2num
(
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
idx
]])));
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]])));
}
else
{
}
else
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
sh_block_topk_weights
[
idx
]
=
Dtype
::
num2num2
(
Dtype
::
num2num2
(
Dtype
::
float2num
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
idx
]]));
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]]));
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment