change / sglang / Commits / d332aa3b
"vscode:/vscode.git/clone" did not exist on "9bca40296e3f00fb26597a0f4cfe2fdfd2ad2fd2"
Unverified commit d332aa3b, authored Dec 07, 2024 by Yineng Zhang; committed by GitHub on Dec 07, 2024.
fix: resolve fp8 moe issue (#2387)
parent c36736c8

Showing 2 changed files with 27 additions and 56 deletions (+27, -56):
python/sglang/srt/layers/quantization/__init__.py (+2, -47)
python/sglang/srt/layers/quantization/fp8.py (+25, -9)
python/sglang/srt/layers/quantization/__init__.py
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]


-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
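For orientation: the deleted patch above reproduced, outside the class, what Fp8MoEMethod.apply now does itself (see the fp8.py diff below). The routing half of that logic can be sketched in a few lines of plain PyTorch. This toy version covers only plain top-k with optional renormalization; FusedMoE.select_experts additionally supports grouped top-k and custom routing functions:

    import torch

    def toy_select_experts(router_logits: torch.Tensor, top_k: int, renormalize: bool):
        """Toy top-k softmax routing; illustrative only, not sglang's implementation."""
        # Per-token probabilities over the experts.
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        # Keep the top_k most probable experts for each token.
        topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
        if renormalize:
            # Rescale so the kept weights sum to 1 per token.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

    # 4 tokens routed over 8 experts, 2 experts per token.
    weights, ids = toy_select_experts(torch.randn(4, 8), top_k=2, renormalize=True)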
@@ -106,7 +62,7 @@ def fp8_get_quant_method(self, layer, prefix):
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod

     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):

 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
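The patching mechanism itself is nothing more than assigning plain functions onto the vllm classes with setattr; a function stored on a class is bound as a method for every existing and future instance. A minimal self-contained sketch of the same pattern with a toy class:

    class Config:
        def get_quant_method(self, layer, prefix):
            return "original"

    def patched_get_quant_method(self, layer, prefix):
        # An ordinary function; once stored on the class it binds like a method.
        return "patched"

    # One assignment swaps the behavior process-wide, which is why
    # apply_monkey_patches() only needs to run once at startup.
    setattr(Config, "get_quant_method", patched_get_quant_method)

    assert Config().get_quant_method(None, "") == "patched"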
python/sglang/srt/layers/quantization/fp8.py
@@ -24,11 +24,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +95,8 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
+        from sglang.srt.layers.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
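This hunk is the standard remedy for a circular import: fused_moe_triton presumably imports from the quantization package in turn, so pulling FusedMoE in at module scope would close the cycle, while importing inside the method defers resolution to call time, when both modules are fully initialized. A generic sketch of the pattern (module names are illustrative, not sglang's layout):

    # routing.py -- illustrative two-module cycle
    import quant                     # safe: quant defers its import of routing

    class Router:
        def quant_method(self, layer):
            return quant.pick(layer)

    # quant.py
    def pick(layer):
        # Deferred import: "import routing" at module scope would form the
        # cycle routing -> quant -> routing and fail against the still
        # half-initialized routing module. At call time it is fully loaded.
        from routing import Router
        return "moe" if isinstance(layer, Router) else "linear"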
@@ -306,7 +303,7 @@ class Fp8LinearMethod(LinearMethodBase):
     )


-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.

     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +316,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """

-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
         self.quant_config = quant_config

     def create_weights(
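The __new__ hook above rebuilds the class with FusedMoEMethodBase injected into its bases the first time it is instantiated, so the base class never has to be imported at module scope (again sidestepping the circular import). Here is the same trick in isolation with toy classes, a sketch assuming nothing beyond the standard library. Note the explicit obj.__init__ call: the rebuilt class is not a subclass of the original, so Python would not invoke __init__ automatically:

    class Base:
        def shared(self):
            return "provided by Base"

    class Lazy:
        """Declared without Base; acquires it at first instantiation."""

        def __new__(cls, *args, **kwargs):
            # Rebuild the class with Base in its bases, copying the namespace
            # minus the per-instance slot descriptors that type() recreates.
            new_cls = type(
                cls.__name__,
                (Base,),
                {k: v for k, v in cls.__dict__.items()
                 if k not in ("__dict__", "__weakref__")},
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            # new_cls is not a subclass of Lazy, so type.__call__ skips
            # __init__; call it explicitly, as the sglang patch does.
            obj.__init__(*args, **kwargs)
            return obj

        def __init__(self, value):
            self.value = value

    obj = Lazy(42)
    assert isinstance(obj, Base) and obj.shared() == "provided by Base"
    assert obj.value == 42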
@@ -331,6 +346,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -521,8 +537,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.fused_moe_triton import FusedMoE
-        from vllm.model_executor.layers.fused_moe import fused_experts
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
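The scale tensors threaded through these calls (w13_weight_scale, w2_input_scale, and so on, visible in the patch deleted from __init__.py) are FP8 scaling factors, per-tensor in the static-checkpoint case judging by the PerTensorScaleParameter import above. A toy sketch of what per-tensor e4m3 quantization computes (requires a PyTorch build with float8 support; not sglang's kernel):

    import torch

    def quantize_fp8(t: torch.Tensor):
        """Toy per-tensor FP8 (e4m3) quantization."""
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = t.abs().max() / finfo.max                # one scale per tensor
        q = (t / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return q, scale

    def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return q.to(torch.float32) * scale

    w = torch.randn(128, 256)
    q, s = quantize_fp8(w)
    # Round-trips within fp8 precision; kernels like fused_experts consume
    # q and s directly instead of materializing the dequantized tensor.
    assert torch.allclose(dequantize_fp8(q, s), w, rtol=0.1, atol=1e-2)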