zhaoyu6 / sglang · Commits · 95f93f49
"vscode:/vscode.git/clone" did not exist on "47af8be9072b26f85c445b90df3c759a52bd7f73"
Commit 95f93f49 (unverified), authored Dec 07, 2024 by HAI, committed by GitHub on Dec 07, 2024

Fp8 MoE optimizations on AMD (#2388)
Parent: aaac33fd

Showing 2 changed files with 97 additions and 22 deletions (+97 -22):

  python/sglang/srt/layers/fused_moe_triton/fused_moe.py  (+64 -21)
  python/sglang/srt/layers/quantization/fp8.py            (+33 -1)
python/sglang/srt/layers/fused_moe_triton/fused_moe.py
```diff
@@ -16,6 +16,7 @@ from vllm import _custom_ops as ops
 from sglang.srt.utils import direct_register_custom_op, get_device_name

 logger = logging.getLogger(__name__)

+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

 @triton.jit
```
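The padding toggle above is evaluated once when fused_moe.py is imported, so the environment variable has to be set before sglang loads. A minimal sketch of the behavior (module path taken from this diff):

```python
# Minimal sketch: padding_size is computed at module import time, so
# MOE_PADDING must be in the environment before the module is first imported.
import os

os.environ["MOE_PADDING"] = "1"   # enable 128-element weight padding (intended for ROCm per this commit)

from sglang.srt.layers.fused_moe_triton import fused_moe

print(fused_moe.padding_size)     # 128 when MOE_PADDING=1, otherwise 0
```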
```diff
@@ -58,6 +59,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
```
```diff
@@ -143,12 +145,21 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(
-            b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
-        )
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(
+                b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
+            )

         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
```
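The new even_Ks fast path drops the K-boundary mask whenever K is an exact multiple of BLOCK_SIZE_K, since every K-tile is then full. A rough NumPy analogy (not Triton, sizes made up) of why that mask is redundant for even K:

```python
# NumPy stand-in for the blocked K loop: when K % BLOCK_SIZE_K == 0 every tile
# is full and the ragged-tail masking can be skipped entirely.
import numpy as np

K, BLOCK_SIZE_K = 256, 64
a = np.random.rand(16, K).astype(np.float32)
b = np.random.rand(K, 32).astype(np.float32)

acc = np.zeros((16, 32), dtype=np.float32)
even_Ks = K % BLOCK_SIZE_K == 0
for k0 in range(0, K, BLOCK_SIZE_K):
    if even_Ks:
        a_tile = a[:, k0:k0 + BLOCK_SIZE_K]        # full tile: no bounds mask needed
        b_tile = b[k0:k0 + BLOCK_SIZE_K, :]
    else:
        width = min(BLOCK_SIZE_K, K - k0)          # ragged tail tile: mask via zero fill
        a_tile = np.zeros((a.shape[0], BLOCK_SIZE_K), dtype=np.float32)
        a_tile[:, :width] = a[:, k0:k0 + width]
        b_tile = np.zeros((BLOCK_SIZE_K, b.shape[1]), dtype=np.float32)
        b_tile[:width, :] = b[k0:k0 + width, :]
    acc += a_tile @ b_tile

assert np.allclose(acc, a @ b, atol=1e-3)
```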
```diff
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

+    padded_size = 0
     if use_fp8_w8a8:
+        padded_size = padding_size
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
     elif use_int8_w8a16:
```
```diff
@@ -268,6 +281,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )

+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
```
```diff
@@ -279,7 +298,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
```
```diff
@@ -296,6 +315,7 @@ def invoke_fused_moe_kernel(
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        even_Ks=even_Ks,
         **config,
     )
```
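A tiny sketch (made-up tensor sizes) of the padded-K bookkeeping in the launch logic above:

```python
# Hypothetical shapes: with MOE_PADDING=1 the stored expert weight carries 128
# extra zero columns, but the kernel is launched with the logical K, so the
# padding is never read in the dot product.
padding_size = 128
B_shape = (8, 14336, 4096 + padding_size)     # (num_experts, N, padded K), made-up sizes
config = {"BLOCK_SIZE_K": 128}

padded_size = padding_size                     # set only on the use_fp8_w8a8 path
K = B_shape[2] - padded_size                   # logical K = 4096
even_Ks = K % config["BLOCK_SIZE_K"] == 0      # True -> maskless inner-loop loads
print(K, even_Ks)                              # 4096 True
```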
```diff
@@ -351,20 +371,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
-        config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
-        }
+    if dtype == "fp8_w8a8":
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
+        }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config
```
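The fp8 default moves to much larger tiles (128x256 with BLOCK_SIZE_K=128). A back-of-the-envelope sketch, with made-up problem sizes, of what that means for the number of Triton programs launched per expert GEMM:

```python
# Made-up sizes: the larger fp8 default tiles mean fewer, larger workgroups
# per GEMM than the previous one-size default.
import math

M, N = 4096, 14336                                        # rows x columns of one GEMM
old = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}            # previous default
fp8 = {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}          # new fp8_w8a8 default (M > E)

for name, cfg in (("old default", old), ("fp8 default", fp8)):
    programs = math.ceil(M / cfg["BLOCK_SIZE_M"]) * math.ceil(N / cfg["BLOCK_SIZE_N"])
    print(f"{name}: {programs} programs")                  # 14336 vs 1792
```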
```diff
@@ -645,8 +684,12 @@ def fused_experts_impl(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8_w8a8:
+        padded_size = 0
+
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
```
```diff
@@ -668,7 +711,7 @@ def fused_experts_impl(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
     )
```
python/sglang/srt/layers/quantization/fp8.py
```diff
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
```
```diff
@@ -24,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
```
```diff
@@ -420,7 +423,7 @@ class Fp8MoEMethod:
     def process_weights_after_loading(self, layer: Module) -> None:

-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
```
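For reference, the two fp8 formats picked in that branch differ in range (the fnuz variant used on MI300-class hardware also has no negative zero); a quick check in a recent PyTorch build that exposes both dtypes:

```python
# Sanity check of the two fp8 dtypes selected above.
import torch

print(torch.finfo(torch.float8_e4m3fn).max)     # 448.0  (CUDA path)
print(torch.finfo(torch.float8_e4m3fnuz).max)   # 240.0  (ROCm / MI300 path)
```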
```diff
@@ -444,6 +447,19 @@ class Fp8MoEMethod:
             )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
```
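A small sketch (made-up expert and weight sizes) of what this padding step does to the stored tensors; per the diff's own comment, the goal is to reduce memory-channel contention on ROCm:

```python
# F.pad with (0, padding_size) zero-pads only the last (contiguous K) dimension,
# so the padded weight is larger in memory but identical on the original columns;
# the kernel later subtracts padded_size so the zeros are never read.
import torch
import torch.nn.functional as F

padding_size = 128
w2 = torch.randn(8, 4096, 14336)                        # (num_experts, hidden, intermediate), made-up sizes
w2_padded = F.pad(w2, (0, padding_size), "constant", 0)

print(w2_padded.shape)                                   # torch.Size([8, 4096, 14464])
assert torch.equal(w2_padded[..., : w2.shape[-1]], w2)   # original values preserved
```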
```diff
@@ -472,6 +488,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False
             )
+
         # If ROCm, normalize the weights and scales to e4m3fnuz
         if is_hip():
             # Normalize the weights and scales
```
```diff
@@ -523,6 +540,19 @@ class Fp8MoEMethod:
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales, requires_grad=False
             )
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

     def apply(
```
```diff
@@ -540,6 +570,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.fused_moe_triton import FusedMoE
         from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts

+        # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
```
```diff
@@ -551,6 +582,7 @@ class Fp8MoEMethod:
             custom_routing_function=custom_routing_function,
         )

+        # Expert fusion with FP8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
```