sglang · Commit f35cb46c (unverified)
Authored by HAI on Nov 21, 2024; committed via GitHub on Nov 21, 2024
ROCm: Fix MoE padding for non-FP8 cases (#2111)
parent 7f8fcd39
Showing 1 changed file with 11 additions and 4 deletions.

python/sglang/srt/layers/fused_moe/fused_moe.py  (+11 −4)
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

+    padded_size = padding_size
     if not use_fp8:
         assert A_scale is None
         assert B_scale is None
+        # MOE_PADDING FP8 only
+        padded_size = 0
     else:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
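The new padded_size local decouples the shape arithmetic from the module-level padding_size: MoE weight padding is only applied in the FP8 path, so the non-FP8 path must subtract nothing. A minimal sketch of that selection logic follows; the helper name effective_padding and the example value 128 are illustrative, not part of the patch.

# Minimal sketch of the padding selection added above.
# `effective_padding` is a hypothetical helper name and 128 is an
# arbitrary example value; only the if/else mirrors the patch.

def effective_padding(use_fp8: bool, padding_size: int) -> int:
    """Padding to subtract from the weights' K dimension.

    MoE padding is applied to the FP8 weights only, so non-FP8 weights
    are treated as unpadded.
    """
    return padding_size if use_fp8 else 0


assert effective_padding(use_fp8=True, padding_size=128) == 128
assert effective_padding(use_fp8=False, padding_size=128) == 0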
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )

-    K = B.shape[2] - padding_size
+    K = B.shape[2] - padded_size
     if K % config["BLOCK_SIZE_K"] == 0:
         even_ks = True
     else:
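With the fix, the effective K dimension is computed from padded_size, so in the non-FP8 case the full last dimension of B is used; even_ks then records whether K is a multiple of the block size, presumably so the kernel can skip K-boundary masking. A small numeric sketch with made-up shapes and block size:

# Numeric sketch of the K / even_ks computation; the shapes, padding value
# and BLOCK_SIZE_K are made up, only the arithmetic mirrors the hunk.
padded_size = 128                          # 0 in the non-FP8 path after this fix
B_shape = (8, 14336, 4096 + padded_size)   # expert weights with a padded K dim
BLOCK_SIZE_K = 64

K = B_shape[2] - padded_size               # 4096, the logical (unpadded) K
even_ks = K % BLOCK_SIZE_K == 0            # True: presumably allows skipping the K-boundary mask
print(K, even_ks)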
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
-        B.shape[2] - padding_size,
+        B.shape[2] - padded_size,
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
@@ -480,8 +483,12 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8:
+        # MOE_PADDING FP8 only
+        padded_size = 0
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
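The final hunk applies the same substitution where the tuning-config lookup is built, so the functools.partial is keyed to the unpadded w2 shape. A generic sketch of that pattern, with a stand-in get_config function (not the real try_get_optimal_moe_config) and made-up shapes:

# Generic sketch of the functools.partial pattern used for the config lookup.
# `get_config` is a stand-in, not the real try_get_optimal_moe_config; the
# shapes and the padding value are made up.
import functools

def get_config(w1_shape, w2_shape, top_k, dtype, M, override_config=None):
    # A real implementation would pick a Triton tiling config for these
    # shapes; here we just return the key it would be looked up under.
    return {"key": (w1_shape, w2_shape, top_k, dtype, M), "override": override_config}

padded_size = 16                                   # 0 in the non-FP8 path
w1_shape = (4, 128, 64)
w2_shape = (4, 64, 128 + padded_size)              # padded K dim on w2

get_config_func = functools.partial(
    get_config,
    w1_shape,
    (w2_shape[0], w2_shape[1], w2_shape[2] - padded_size),  # unpadded K
    2,                                             # topk
    "float8",                                      # or None when use_fp8 is False
    override_config=None,
)

print(get_config_func(M=256))                      # M supplied per batch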