change / sglang / Commits / 30b404ce

Commit 30b404ce (unverified), authored Sep 13, 2024 by Jerry Zhang, committed via GitHub on Sep 14, 2024

Add torchao quant for mixtral and qwen_moe (#1418)
parent 70b68029

Showing 4 changed files with 50 additions and 20 deletions (+50 -20)
python/sglang/srt/layers/torchao_utils.py   +38  -1
python/sglang/srt/models/llama.py            +2 -19
python/sglang/srt/models/mixtral.py          +5  -0
python/sglang/srt/models/qwen2_moe.py        +5  -0
python/sglang/srt/layers/torchao_utils.py  (view file @ 30b404ce)

@@ -2,10 +2,20 @@
 Common utilities for torchao.
 """

+from typing import Dict, Set
+
 import torch


-def torchao_quantize_param_data(param, torchao_config):
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+       `param`: weight parameter of the linear module
+       `torchao_config`: type of quantization and their arguments we want to use to
+        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with
+        group_size 128
+    """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         int4_weight_only,
@@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config):
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
         quantize_(dummy_linear, float8_weight_only())
     return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """A util function used for quantizing the weight parameters after they are loaded if
+    self.torchao_config is specified
+
+    Args:
+      `self`: the model we want to quantize
+      `params_dict`: dictionary mapping from param_name to the parameter Tensor
+      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
+
+    Returns:
+      None, the `params_dict` is modified inplace and the weights of `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
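For context, the pattern the helper builds on looks roughly like the sketch below. This is a minimal illustration, not the code from this commit: it assumes only the "int4wo-<group_size>" form of torchao_config, that torchao is installed, and that the weight is already an nn.Parameter; the real torchao_quantize_param_data handles more config strings.

import torch
from torchao.quantization import int4_weight_only, quantize_


def quantize_weight_sketch(param: torch.nn.Parameter, torchao_config: str) -> torch.Tensor:
    # Wrap the raw weight in a throwaway Linear so torchao's module-level
    # quantize_ API can be applied, then return the (now quantized) weight.
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
    if torchao_config.startswith("int4wo"):
        group_size = int(torchao_config.split("-")[1])  # "int4wo-128" -> 128
        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
    return dummy_linear.weight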
python/sglang/srt/models/llama.py  (view file @ 30b404ce)

@@ -41,7 +41,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -405,24 +405,7 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if self.torchao_config:
-                    if name.endswith("proj.weight") and param.ndim == 2:
-                        params_dict[name] = torchao_quantize_param_data(
-                            param, self.torchao_config
-                        )
-
-        if self.torchao_config:
-            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
-            stacked_params = set(entry[0] for entry in stacked_params_mapping)
-            for param_suffix in stacked_params:
-                for name in params_dict:
-                    if param_suffix in name:
-                        param = params_dict[name]
-                        params_dict[name] = torchao_quantize_param_data(
-                            param, self.torchao_config
-                        )
-            self.load_state_dict(params_dict, assign=True)
+
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))


 class Phi3ForCausalLM(LlamaForCausalLM):
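The effect of the llama.py change is that the per-model quantization loops move into the shared helper: load_weights now just loads every checkpoint tensor and finishes with a single apply_torchao_config_ call. A hedged sketch of how another model definition would adopt the same pattern (the class and the elided method bodies are placeholders for illustration, not part of this commit):

import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict


class MyModelForCausalLM(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self, config, quant_config=None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        # Pick up the requested torchao config (e.g. "int4wo-128") from the server args.
        self.torchao_config = global_server_args_dict["torchao_config"]
        # ... build the transformer stack and lm_head here ...

    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        # ... copy each checkpoint tensor into params_dict via its weight_loader ...
        # After loading, quantize every 2-D parameter whose name contains "proj.weight".
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))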
python/sglang/srt/models/mixtral.py  (view file @ 30b404ce)

@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -296,6 +298,7 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -376,5 +379,7 @@ class MixtralForCausalLM(nn.Module):
                     )
                 weight_loader(param, loaded_weight)

+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+

 EntryClass = MixtralForCausalLM
python/sglang/srt/models/qwen2_moe.py  (view file @ 30b404ce)

@@ -47,6 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -359,6 +361,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -451,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                     )
                 weight_loader(param, loaded_weight)

+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+

 EntryClass = Qwen2MoeForCausalLM
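In both MoE models the call passes param_suffixes={"proj.weight"}, and apply_torchao_config_ additionally requires param.ndim == 2, so only 2-D projection matrices whose names contain that suffix are quantized; everything else is left untouched. A small self-contained illustration of that selection logic (the parameter names and shapes below are made up for the example, not taken from the commit):

import torch

# Toy stand-in for a model's params_dict; names and shapes are illustrative only.
params_dict = {
    "model.layers.0.self_attn.qkv_proj.weight": torch.nn.Parameter(torch.empty(96, 64)),
    "model.layers.0.self_attn.o_proj.weight": torch.nn.Parameter(torch.empty(64, 64)),
    "model.layers.0.input_layernorm.weight": torch.nn.Parameter(torch.empty(64)),  # 1-D, skipped
    "model.embed_tokens.weight": torch.nn.Parameter(torch.empty(1000, 64)),  # no suffix match, skipped
}

# The same filter apply_torchao_config_ applies: suffix substring match plus a 2-D check.
selected = [
    name for name, p in params_dict.items() if "proj.weight" in name and p.ndim == 2
]
print(selected)
# -> only the two *_proj.weight matrices are selected for torchao quantization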