Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1228f7ca
Unverified
Commit
1228f7ca
authored
Dec 03, 2024
by
Lianmin Zheng
Committed by
GitHub
Dec 03, 2024
Browse files
Fix gptq for moe layers (#2300)
Co-authored-by:
root
<
me@zhyncs.com
>
parent
fda628d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
2 deletions
+44
-2
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+34
-0
python/sglang/srt/models/mixtral.py
python/sglang/srt/models/mixtral.py
+10
-2
No files found.
python/sglang/srt/layers/quantization/__init__.py
View file @
1228f7ca
...
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
...
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
return
None
return
None
def
gptq_get_quant_method
(
self
,
layer
,
prefix
):
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinLinearMethod
,
GPTQMarlinMoEMethod
,
)
from
sglang.srt.layers.fused_moe_triton.layer
import
FusedMoE
if
isinstance
(
layer
,
LinearBase
):
return
GPTQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
GPTQMarlinMoEMethod
(
self
)
return
None
def
awq_get_quant_method
(
self
,
layer
,
prefix
):
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
AWQMarlinLinearMethod
,
AWQMoEMethod
,
)
from
sglang.srt.layers.fused_moe_triton.layer
import
FusedMoE
if
isinstance
(
layer
,
LinearBase
):
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
AWQMoEMethod
(
self
)
return
None
def
apply_monkey_patches
():
def
apply_monkey_patches
():
"""Apply all monkey patches in one place."""
"""Apply all monkey patches in one place."""
setattr
(
Fp8MoEMethod
,
"apply"
,
fp8_moe_apply
)
setattr
(
Fp8MoEMethod
,
"apply"
,
fp8_moe_apply
)
setattr
(
Fp8Config
,
"get_quant_method"
,
fp8_get_quant_method
)
setattr
(
Fp8Config
,
"get_quant_method"
,
fp8_get_quant_method
)
setattr
(
GPTQMarlinConfig
,
"get_quant_method"
,
gptq_get_quant_method
)
setattr
(
AWQMarlinConfig
,
"get_quant_method"
,
awq_get_quant_method
)
# Apply patches when module is imported
# Apply patches when module is imported
...
...
python/sglang/srt/models/mixtral.py
View file @
1228f7ca
...
@@ -339,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
...
@@ -339,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
(
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
)
)
and
name
not
in
params_dict
:
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
...
@@ -353,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
...
@@ -353,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
if
(
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
)
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
weight_loader
(
...
@@ -365,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
...
@@ -365,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
(
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
)
)
and
name
not
in
params_dict
:
continue
continue
# Skip loading kv_scale from ckpts towards new design.
# Skip loading kv_scale from ckpts towards new design.
if
name
.
endswith
(
".kv_scale"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".kv_scale"
)
and
name
not
in
params_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment