sglang · Commit 9cc733b3 (unverified)

move apply_torchao_config_ to model_runner (#2342)

Authored Dec 04, 2024 by Jerry Zhang; committed by GitHub on Dec 04, 2024
Parent: d693ec04
Showing 8 changed files with 25 additions and 71 deletions (+25 -71):
python/sglang/srt/layers/torchao_utils.py         +17 -41
python/sglang/srt/model_executor/model_runner.py   +8  -0
python/sglang/srt/models/grok.py                    +0  -5
python/sglang/srt/models/llama.py                   +0  -5
python/sglang/srt/models/mixtral.py                 +0  -5
python/sglang/srt/models/phi3_small.py              +0  -5
python/sglang/srt/models/qwen2_moe.py               +0  -5
python/sglang/srt/models/torch_native_llama.py      +0  -5
python/sglang/srt/layers/torchao_utils.py
@@ -7,13 +7,15 @@ from typing import Dict, Set
 import torch
 
 
-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model_(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a model with torchao quantization specified by torchao_config
 
     Args:
-       `param`: weight parameter of the linear module
-       `torchao_config`: type of quantization and their arguments we want to use to
-        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+       `model`: a model to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+        quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
         128
     """
     # Lazy import to suppress some warnings
@@ -26,12 +28,12 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = param
-    if "int8wo" in torchao_config:
-        quantize_(dummy_linear, int8_weight_only())
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +42,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
         from torchao.quantization import float8_weight_only
 
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(dummy_linear, float8_weight_only())
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +59,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-            dummy_linear,
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
 
-    return dummy_linear.weight
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-       self.torchao_config is specified
-
-    Args:
-      `self`: the model we want to quantize
-      `params_dict`: dictionary mapping from param_name to the parameter Tensor
-      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-       None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
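For context, a minimal hypothetical sketch (not part of the commit) of the model-level flow the new apply_torchao_config_to_model_ helper wraps for a config such as "int8wo": torchao's quantize_ walks the module tree and swaps weights in place for every submodule accepted by filter_fn. The toy model and filter below are invented for illustration.

import torch
from torchao.quantization import int8_weight_only, quantize_

# toy two-layer model standing in for a real sglang model
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 16, bias=False),
)

def filter_fn(module, fqn):
    # quantize every Linear here; the real call site filters on "proj" in the FQN
    return isinstance(module, torch.nn.Linear)

# roughly what apply_torchao_config_to_model_(model, "int8wo", filter_fn) does
quantize_(model, int8_weight_only(), filter_fn=filter_fn)
print(type(model[0].weight))  # weights are now torchao quantized tensor subclasses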
python/sglang/srt/model_executor/model_runner.py
@@ -38,6 +38,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -159,6 +160,13 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+        apply_torchao_config_to_model_(
+            self.model, global_server_args_dict["torchao_config"], filter_fn
+        )
+
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
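A brief hypothetical illustration of what the "proj"-based filter_fn above selects: any submodule whose fully-qualified name contains "proj" (q/k/v/o and MLP projections in typical decoder blocks), while norms, embeddings, and the root module are skipped. ToyBlock is an invented stand-in, not sglang code.

import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8, bias=False)
        self.o_proj = torch.nn.Linear(8, 8, bias=False)
        self.norm = torch.nn.LayerNorm(8)

def filter_fn(module, fqn):
    return "proj" in fqn

block = ToyBlock()
selected = [name for name, module in block.named_modules() if filter_fn(module, name)]
print(selected)  # ['q_proj', 'o_proj'] -- 'norm' and the root module are left alone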
python/sglang/srt/models/grok.py
@@ -35,12 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -290,7 +288,6 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -374,8 +371,6 @@
                 )
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
     """An alias for backward-compatbility."""
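The same removal repeats in each model file below. The old per-model call matched parameter names against the suffix "proj.weight", whereas the new ModelRunner hook filters modules by fully-qualified name. A hypothetical sketch (invented Attn class, not sglang code) showing that the two selection rules pick the same weights on a typical attention block:

import torch

class Attn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8, bias=False)
        self.o_proj = torch.nn.Linear(8, 8, bias=False)

attn = Attn()

# old style (removed here): suffix match over the loaded params_dict keys
params_dict = dict(attn.named_parameters())
old_selection = {n for n, p in params_dict.items() if "proj.weight" in n and p.ndim == 2}

# new style (model_runner.py): filter modules by FQN, then take their weights
new_selection = {f"{n}.weight" for n, m in attn.named_modules() if "proj" in n}

assert old_selection == new_selection == {"q_proj.weight", "o_proj.weight"}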
python/sglang/srt/models/llama.py
@@ -36,12 +36,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -304,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         # Llama 3.2 1B Insturct set tie_word_embeddings to True
         # Llama 3.1 8B Insturct set tie_word_embeddings to False
@@ -424,8 +421,6 @@
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100, tp_size: int = 1
     ) -> Optional[torch.Tensor]:
python/sglang/srt/models/mixtral.py
@@ -34,12 +34,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -295,7 +293,6 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -387,7 +384,5 @@
                 )
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = MixtralForCausalLM
python/sglang/srt/models/phi3_small.py
@@ -17,13 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -348,7 +346,6 @@ class Phi3SmallForCausalLM(nn.Module):
             quant_config=quant_config,
             prefix="model",
         )
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
         self.lm_head = ParallelLMHead(
@@ -441,7 +438,5 @@
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Phi3SmallForCausalLM
python/sglang/srt/models/qwen2_moe.py
@@ -40,12 +40,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -352,7 +350,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -445,7 +442,5 @@
                 )
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Qwen2MoeForCausalLM
python/sglang/srt/models/torch_native_llama.py
@@ -58,12 +58,10 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -392,7 +390,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         if self.config.tie_word_embeddings:
@@ -503,8 +500,6 @@
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass