change / sglang · Commits

Commit 862bcff8 (unverified), authored Jan 22, 2025 by Ke Wen, committed via GitHub on Jan 22, 2025
Parent: 8b84e69f

Support loading of larger models with on-the-fly quantization (#3061)

Showing 6 changed files with 116 additions and 14 deletions (+116, -14)
python/sglang/srt/configs/load_config.py          +1   -0
python/sglang/srt/layers/torchao_utils.py         +12  -6
python/sglang/srt/model_executor/model_runner.py  +6   -3
python/sglang/srt/model_loader/loader.py          +75  -0
python/sglang/srt/models/torch_native_llama.py    +17  -4
python/sglang/srt/server_args.py                  +5   -1
python/sglang/srt/configs/load_config.py

@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
+    LAYERED = "layered"
 
 
 @dataclass
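For context, `LoadFormat` subclasses `str`, so the new member converts from and compares equal to the raw `--load-format layered` string coming off the CLI. A quick standalone illustration of that pattern (re-declared here with two members, not imported from sglang):

import enum

class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    LAYERED = "layered"

# A str-valued enum lets CLI strings convert and compare directly.
assert LoadFormat("layered") is LoadFormat.LAYERED
assert LoadFormat.LAYERED == "layered"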
python/sglang/srt/layers/torchao_utils.py

@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional
 
 import torch

@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
     return True
 
 
+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module, torchao_config: str, filter_fn=None
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
     """Quantize a model with torchao quantization specified by torchao_config

@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    if filter_fn is None:
-
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
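The net effect: the inline `filter_fn is None` fallback is promoted to a named default, `proj_filter`, so by default only submodules whose fully-qualified name contains "proj" are quantized, and an explicit `None` can now mean "no filter". For intuition, a minimal sketch of how a filter of this shape behaves with torchao's `quantize_` API (toy model; assumes torchao is installed, and not the sglang wrapper itself):

import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(16, 32)  # name matches the "proj" filter
        self.lm_head = nn.Linear(16, 32)       # skipped: no "proj" in its fqn

model = Block()
# Quantize only modules whose fully-qualified name contains "proj",
# mirroring the proj_filter default added in this commit.
quantize_(model, int8_weight_only(), filter_fn=lambda mod, fqn: "proj" in fqn)
print(type(model.gate_up_proj.weight))  # converted by torchao
print(type(model.lm_head.weight))       # untouched nn.Parameter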
python/sglang/srt/model_executor/model_runner.py

@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()
 
         # Apply torchao quantization
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
 
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
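The `torchao_applied` flag keeps quantization idempotent: the layered loader quantizes during loading and marks the model, so the runner must not quantize a second time. A minimal sketch of the handshake (the attribute name follows the diff; `quantize_fn` is a hypothetical stand-in):

import torch.nn as nn

def maybe_quantize(model: nn.Module, quantize_fn) -> nn.Module:
    # Skip if a loader (e.g. LayeredModelLoader) already quantized the model.
    if getattr(model, "torchao_applied", False):
        return model
    quantize_fn(model)
    model.torchao_applied = True  # mark so later callers are no-ops
    return model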
python/sglang/srt/model_loader/loader.py

@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()
 
 
+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""

@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
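The loader's core trick is PyTorch's meta-device initialization combined with per-module `to_empty(..., recurse=False)`: the model exists only as shapes until each submodule is materialized, filled, and optionally quantized in post-order, so full-precision weights for the whole model never have to be resident at once. A self-contained sketch of that memory pattern (toy model and plain state dict, not the sglang code path):

import torch
import torch.nn as nn

def fill_layer_by_layer(module, fqn, state_dict, device):
    # Post-order: materialize and fill children before this module's own tensors.
    for name, child in module.named_children():
        fill_layer_by_layer(child, f"{fqn}.{name}" if fqn else name,
                            state_dict, device)
    # Allocate real storage for this module's direct parameters only.
    module.to_empty(device=device, recurse=False)
    with torch.no_grad():
        for pname, param in module.named_parameters(recurse=False):
            key = f"{fqn}.{pname}" if fqn else pname
            param.copy_(state_dict[key])
    # A per-layer quantization hook would run here, before the next layer loads.

# Reference weights on CPU, then the same architecture on the meta device.
ref = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
fill_layer_by_layer(model, "", ref.state_dict(), torch.device("cpu"))
print(torch.equal(model[0].weight, ref[0].weight))  # True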
python/sglang/srt/models/torch_native_llama.py

@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights_to_module(
+        self,
+        fqn: str,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto submodule pointed by path `fqn`."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),

@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        module = self.get_submodule(fqn)
+        params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:

@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader

@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto the full model."""
+        self.load_weights_to_module("", weights)
+
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
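The refactor hinges on two stock `nn.Module` calls: `get_submodule(fqn)` resolves the dotted path, and `named_parameters(prefix=fqn, recurse=False)` lists only that submodule's direct parameters while keeping checkpoint-style key names, so the existing name-matching loop works unchanged on a single layer. A small illustration (toy module, not the Llama class):

import torch.nn as nn

class Mlp(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(8, 16)
        self.down_proj = nn.Linear(16, 8)

m = Mlp()
sub = m.get_submodule("down_proj")
# recurse=False limits the dict to this submodule's own parameters;
# prefix="down_proj" keeps the keys aligned with full checkpoint names.
print(sorted(dict(sub.named_parameters(prefix="down_proj", recurse=False))))
# ['down_proj.bias', 'down_proj.weight']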
python/sglang/srt/server_args.py

@@ -317,6 +317,7 @@ class ServerArgs:
             "dummy",
             "gguf",
             "bitsandbytes",
+            "layered",
         ],
         help="The format of the model weights to load. "
         '"auto" will try to load the weights in the safetensors format '

@@ -330,7 +331,10 @@ class ServerArgs:
         "which is mainly for profiling."
         '"gguf" will load the weights in the gguf format. '
         '"bitsandbytes" will load the weights using bitsandbytes '
-        "quantization.",
+        "quantization."
+        '"layered" loads weights layer by layer so that one can quantize a '
+        "layer before loading another to make the peak memory envelope "
+        "smaller.",
     )
     parser.add_argument(
         "--trust-remote-code",
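Putting the pieces together, the new format is selected at launch time alongside a torchao config; a hypothetical invocation, with the model path as a placeholder and "int8wo" taken from the configs handled in torchao_utils.py (exact flag spellings follow the server args, not verified here):

    python -m sglang.launch_server --model-path <model> --load-format layered --torchao-config int8wo

With this combination, get_model_loader dispatches to LayeredModelLoader, each projection layer is quantized as it is loaded, and ModelRunner skips its own quantization pass because torchao_applied is set.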