Commit 30b404ce (unverified)
Authored Sep 13, 2024 by Jerry Zhang; committed by GitHub on Sep 14, 2024

Add torchao quant for mixtral and qwen_moe (#1418)
Parent: 70b68029

Showing 4 changed files with 50 additions and 20 deletions:
  python/sglang/srt/layers/torchao_utils.py  (+38, -1)
  python/sglang/srt/models/llama.py          (+2, -19)
  python/sglang/srt/models/mixtral.py        (+5, -0)
  python/sglang/srt/models/qwen2_moe.py      (+5, -0)
python/sglang/srt/layers/torchao_utils.py

@@ -2,10 +2,20 @@
 Common utilities for torchao.
 """
 
+from typing import Dict, Set
+
 import torch
 
 
-def torchao_quantize_param_data(param, torchao_config):
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+       `param`: weight parameter of the linear module
+       `torchao_config`: type of quantization and their arguments we want to use to
+        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+        128
+    """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         int4_weight_only,

@@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config):
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
         quantize_(dummy_linear, float8_weight_only())
     return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """A util function used for quantizing the weight parameters after they are loaded if
+       self.torchao_config is specified
+
+    Args:
+      `self`: the model we want to quantize
+      `params_dict`: dictionary mapping from param_name to the parameter Tensor
+      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
+
+    Returns:
+       None, the `params_dict` is modified inplace and the weights of `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
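For reference, the `int4wo-128` string documented in the docstring above can be exercised directly. The snippet below is an illustrative sketch only, not part of the commit: it assumes torchao is installed and a CUDA device is available (the int4 weight-only path expects a bf16 CUDA tensor), and the weight shape is an arbitrary placeholder.

    import torch
    from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

    # A 2-D weight such as a linear projection; "int4wo-128" means int4
    # weight-only quantization with group_size 128, as described above.
    weight = torch.nn.Parameter(
        torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
        requires_grad=False,
    )
    quantized = torchao_quantize_param_data(weight, "int4wo-128")
    print(type(quantized), quantized.shape)  # quantized weight keeps the logical shape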
python/sglang/srt/models/llama.py

@@ -41,7 +41,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -405,24 +405,7 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if self.torchao_config:
-                    if name.endswith("proj.weight") and param.ndim == 2:
-                        params_dict[name] = torchao_quantize_param_data(
-                            param, self.torchao_config
-                        )
-
-        if self.torchao_config:
-            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
-            stacked_params = set(entry[0] for entry in stacked_params_mapping)
-            for param_suffix in stacked_params:
-                for name in params_dict:
-                    if param_suffix in name:
-                        param = params_dict[name]
-                        params_dict[name] = torchao_quantize_param_data(
-                            param, self.torchao_config
-                        )
-            self.load_state_dict(params_dict, assign=True)
 
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
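One point the llama.py change relies on: apply_torchao_config_ walks the full params_dict after load_weights has already merged the stacked projections, so plain substring matching on the "proj.weight" suffix covers both the per-layer projections and the stacked ones. That is why the separate stacked-params pass above could be deleted. A tiny stand-alone illustration (the parameter names are representative, not taken from the diff):

    param_names = [
        "model.layers.0.self_attn.qkv_proj.weight",  # stacked q/k/v projection
        "model.layers.0.mlp.gate_up_proj.weight",    # stacked gate/up projection
        "model.layers.0.mlp.down_proj.weight",
        "model.layers.0.input_layernorm.weight",     # 1-D, skipped by the ndim == 2 check
    ]
    print([name for name in param_names if "proj.weight" in name])
    # all three *_proj.weight entries match; the layernorm weight does not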
python/sglang/srt/models/mixtral.py

@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -296,6 +298,7 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

@@ -376,5 +379,7 @@ class MixtralForCausalLM(nn.Module):
                 )
             weight_loader(param, loaded_weight)
 
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
 
 EntryClass = MixtralForCausalLM
python/sglang/srt/models/qwen2_moe.py

@@ -47,6 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -359,6 +361,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config

@@ -451,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                 )
             weight_loader(param, loaded_weight)
 
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
 
 EntryClass = Qwen2MoeForCausalLM
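Taken together, the three model files now share one pattern: read the torchao config once in __init__, load weights normally, then quantize every matching 2-D weight in a single pass at the end of load_weights. The sketch below condenses that pattern for illustration only; it is not from the commit, the module and weight names are made up, and it assumes torchao plus a CUDA device. In the real models the config string comes from global_server_args_dict["torchao_config"] (the server's torchao_config argument) rather than being hard-coded.

    import torch
    from sglang.srt.layers.torchao_utils import apply_torchao_config_

    class ToyForCausalLM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # In the real models: global_server_args_dict["torchao_config"]
            self.torchao_config = "int4wo-128"
            self.down_proj = torch.nn.Linear(
                4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda"
            )

        def load_weights(self, weights):
            params_dict = dict(self.named_parameters())
            for name, loaded_weight in weights:
                params_dict[name].data.copy_(loaded_weight)
            # Quantize every matching 2-D weight in one pass, then re-assign.
            apply_torchao_config_(self, params_dict, set(["proj.weight"]))

    model = ToyForCausalLM()
    model.load_weights(
        [("down_proj.weight", torch.randn(4096, 4096, dtype=torch.bfloat16))]
    )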