sglang · Commits · 95f93f49

Commit 95f93f49 (Unverified), authored Dec 07, 2024 by HAI, committed by GitHub, Dec 07, 2024

Fp8 MoE optimizations on AMD (#2388)

parent aaac33fd
Changes: 2 changed files with 97 additions and 22 deletions (+97 -22)

    python/sglang/srt/layers/fused_moe_triton/fused_moe.py    +64 -21
    python/sglang/srt/layers/quantization/fp8.py               +33 -1
python/sglang/srt/layers/fused_moe_triton/fused_moe.py    View file @ 95f93f49
...
@@ -16,6 +16,7 @@ from vllm import _custom_ops as ops
 from sglang.srt.utils import direct_register_custom_op, get_device_name

 logger = logging.getLogger(__name__)

+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


 @triton.jit
...
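The MOE_PADDING environment variable is the only switch for this feature. As a hedged aside (not part of the commit), the standalone snippet below mirrors the expression above and shows how the toggle is read; the helper name is made up for illustration.

import os

def moe_padding_size() -> int:
    # Mirrors the module-level expression: MOE_PADDING=1 (or any non-zero
    # integer) enables 128-element padding of the weight K dimension;
    # unset or "0" disables it.
    return 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

os.environ["MOE_PADDING"] = "1"
assert moe_padding_size() == 128
os.environ["MOE_PADDING"] = "0"
assert moe_padding_size() == 0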
@@ -58,6 +59,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
...
@@ -143,12 +145,21 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(
-            b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
-        )
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(
+                b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
+            )
+
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
...
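Why the unmasked fast path is safe: when K is an exact multiple of BLOCK_SIZE_K, every K-tile is full, so a load without a bounds mask can never read past the end of the tensor. The numpy sketch below (illustrative only, not from the commit) walks the same tiling to make that concrete.

import numpy as np

K, BLOCK_SIZE_K = 256, 64
assert K % BLOCK_SIZE_K == 0  # the even_Ks condition

a = np.arange(K, dtype=np.float32)
total = 0.0
for k in range(K // BLOCK_SIZE_K):
    # Each tile is full-width, so no out-of-bounds mask is needed.
    tile = a[k * BLOCK_SIZE_K : (k + 1) * BLOCK_SIZE_K]
    assert tile.size == BLOCK_SIZE_K
    total += tile.sum()
assert total == a.sum()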
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
padded_size
=
0
if
use_fp8_w8a8
:
padded_size
=
padding_size
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
assert
B_scale
is
not
None
elif
use_int8_w8a16
:
...
...
@@ -268,6 +281,12 @@ def invoke_fused_moe_kernel(
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
"BLOCK_SIZE_N"
]),
)
K
=
B
.
shape
[
2
]
-
padded_size
if
K
%
config
[
"BLOCK_SIZE_K"
]
==
0
:
even_Ks
=
True
else
:
even_Ks
=
False
fused_moe_kernel
[
grid
](
A
,
B
,
...
...
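Note that even_Ks is evaluated on the logical K: the padded columns are excluded before the divisibility check. A hedged worked example with illustrative sizes, not taken from the commit:

# Illustrative numbers only: hidden size 4096, MOE_PADDING enabled.
hidden_size = 4096
padding_size = 128
w_k_dim = hidden_size + padding_size      # 4224: stored (padded) K dimension of the weight
K = w_k_dim - padding_size                # 4096: what the kernel actually reduces over
BLOCK_SIZE_K = 128
even_Ks = K % BLOCK_SIZE_K == 0           # True, so the masked tail loads are skipped
assert (K, even_Ks) == (4096, True)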
@@ -279,7 +298,7 @@ def invoke_fused_moe_kernel(
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
B
.
shape
[
2
],
B
.
shape
[
2
]
-
padded_size
,
sorted_token_ids
.
shape
[
0
],
topk_ids
.
numel
(),
A
.
stride
(
0
),
...
...
@@ -296,6 +315,7 @@ def invoke_fused_moe_kernel(
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
even_Ks
=
even_Ks
,
**
config
,
)
...
...
@@ -351,20 +371,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
-        config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
-        }
+    if dtype == "fp8_w8a8":
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
+        }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config
...
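As a hedged illustration of how the new branch behaves (a standalone toy mirror, not the commit's code), the snippet below reproduces the selection logic and shows that under fp8_w8a8 a decode-sized batch gets the small-M tiles while a prefill-sized batch gets the large 128x256 tiles.

from typing import Dict, Optional

def pick_default_config(M: int, E: int, dtype: Optional[str], is_marlin: bool) -> Dict[str, int]:
    # Toy mirror of the heuristic above.
    if dtype == "fp8_w8a8":
        config = {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128,
                  "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4}
        if M <= E:
            config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
                      "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}
    else:
        config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}
        if M <= E or (is_marlin and M <= 32):
            config = {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1}
    return config

# With 8 experts: a single decode token takes the small-M fp8 config,
# while a 4096-token prefill takes the large-tile fp8 config.
assert pick_default_config(1, 8, "fp8_w8a8", False)["BLOCK_SIZE_M"] == 64
assert pick_default_config(4096, 8, "fp8_w8a8", False)["BLOCK_SIZE_M"] == 128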
@@ -645,8 +684,12 @@ def fused_experts_impl(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
):
padded_size
=
padding_size
if
not
use_fp8_w8a8
:
padded_size
=
0
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
]
-
padded_size
,
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
...
...
@@ -668,7 +711,7 @@ def fused_experts_impl(
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
(
w2
.
shape
[
0
],
w2
.
shape
[
1
],
w2
.
shape
[
2
]
-
padded_size
)
,
topk_ids
.
shape
[
1
],
config_dtype
,
)
...
...
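Both hunks are shape bookkeeping: the stored expert weights carry padding_size extra zero columns in their last (K) dimension, and the fp8 path subtracts padded_size again so shape checks and config lookup see the logical hidden size. A hedged sketch with toy sizes (not the commit's code):

import torch
import torch.nn.functional as F

padding_size = 128
num_experts, hidden_size, intermediate_size = 4, 256, 512

# w2 as stored after weight loading with MOE_PADDING=1 on ROCm.
w2 = torch.zeros(num_experts, hidden_size, intermediate_size)
w2_padded = F.pad(w2, (0, padding_size), "constant", 0)   # pads the last dim only

padded_size = padding_size  # fp8_w8a8 path; would be 0 otherwise
logical_shape = (w2_padded.shape[0], w2_padded.shape[1], w2_padded.shape[2] - padded_size)
assert logical_shape == (num_experts, hidden_size, intermediate_size)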
python/sglang/srt/layers/quantization/fp8.py    View file @ 95f93f49
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
...
@@ -24,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
...
@@ -420,7 +423,7 @@ class Fp8MoEMethod:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16
or bfloat16
, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype
=
torch
.
float8_e4m3fnuz
if
is_hip
()
else
torch
.
float8_e4m3fn
...
...
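A hedged aside on the dtype choice (not part of the diff): ROCm MI300-class hardware uses the e4m3fnuz FP8 variant, which drops negative zero and shifts the exponent bias relative to e4m3fn, so the target dtype has to be picked per platform before quantizing. A minimal sketch, assuming a recent PyTorch that exposes both dtypes:

import torch

def pick_fp8_dtype(on_rocm: bool) -> torch.dtype:
    # MI300x (ROCm) exposes float8_e4m3fnuz; NVIDIA GPUs use float8_e4m3fn.
    return torch.float8_e4m3fnuz if on_rocm else torch.float8_e4m3fn

x = torch.randn(4, 4)
q = x.to(pick_fp8_dtype(on_rocm=False))   # illustrative cast; real code applies scales first
assert q.dtype == torch.float8_e4m3fn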
@@ -444,6 +447,19 @@ class Fp8MoEMethod:
             )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
...
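The padding is purely a memory-layout optimization (reducing memory-channel contention on MI300x, per the comment above): the extra columns sit past the logical K and the kernel never reads them, so results are unchanged whether MOE_PADDING is set or not. A hedged numerical sketch with toy sizes, not the commit's code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, N, padding_size = 64, 32, 128
w = torch.randn(N, K)                                   # one expert's weight, K last
x = torch.randn(8, K)                                   # activations are never padded

w_padded = F.pad(w, (0, padding_size), "constant", 0)   # (N, K + 128), zeros appended
out_ref = x @ w.t()
out_pad = x @ w_padded[:, :K].t()                       # only the logical K columns are used
assert torch.allclose(out_ref, out_pad)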
@@ -472,6 +488,7 @@ class Fp8MoEMethod:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
# Normalize the weights and scales
...
...
@@ -523,6 +540,19 @@ class Fp8MoEMethod:
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
# If ROCm, apply weight padding (min. Mem channel contention) only if set
if
is_hip
()
and
bool
(
int
(
os
.
getenv
(
"MOE_PADDING"
,
"0"
))):
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
F
.
pad
(
layer
.
w13_weight
.
data
,
(
0
,
padding_size
),
"constant"
,
0
),
requires_grad
=
False
,
)
torch
.
cuda
.
empty_cache
()
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
F
.
pad
(
layer
.
w2_weight
.
data
,
(
0
,
padding_size
),
"constant"
,
0
),
requires_grad
=
False
,
)
torch
.
cuda
.
empty_cache
()
return
def
apply
(
...
...
@@ -540,6 +570,7 @@ class Fp8MoEMethod:
from
sglang.srt.layers.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.fused_moe_triton.fused_moe
import
fused_experts
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -551,6 +582,7 @@ class Fp8MoEMethod:
custom_routing_function
=
custom_routing_function
,
)
# Expert fusion with FP8 quantization
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
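The two new comments label the two stages of apply(). As a hedged toy sketch (illustrative only, not the sglang API), the snippet below shows the shape of that flow: softmax the router logits, keep the top-k experts per token, then hand the routing weights and ids to a fused expert computation.

import torch

def toy_select_experts(router_logits: torch.Tensor, top_k: int):
    # Stage 1: expert selection - per-token softmax over experts, keep top-k,
    # renormalize the kept weights.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

num_tokens, num_experts, top_k = 6, 8, 2
logits = torch.randn(num_tokens, num_experts)
weights, ids = toy_select_experts(logits, top_k)
assert weights.shape == (num_tokens, top_k) and ids.shape == (num_tokens, top_k)
# Stage 2 (not shown): fused_experts groups tokens by expert id and runs the
# FP8 grouped GEMMs over w13/w2, weighting outputs by these routing weights.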