change / sglang · Commits · 85e1a6f3

Unverified commit 85e1a6f3, authored Dec 02, 2024 by Yineng Zhang, committed by GitHub on Dec 02, 2024.

Update model_loader deps and qqq quantization deps (#2220) (#2318)

Co-authored-by: HandH1998 <1335248067@qq.com>

parent 33deca81

Changes: 58. Showing 20 changed files with 2167 additions and 172 deletions (+2167, -172).
python/sglang/bench_one_batch.py                        +4     -0
python/sglang/srt/configs/device_config.py              +17    -0
python/sglang/srt/configs/load_config.py                +84    -0
python/sglang/srt/configs/model_config.py               +161   -4
python/sglang/srt/configs/qwen2vl.py                    +5     -8
python/sglang/srt/hf_transformers_utils.py              +2     -0
python/sglang/srt/layers/linear.py                      +1     -0
python/sglang/srt/lora/lora.py                          +1     -1
python/sglang/srt/managers/scheduler.py                 +3     -0
python/sglang/srt/managers/tokenizer_manager.py         +3     -0
python/sglang/srt/managers/tp_worker.py                 +3     -0
python/sglang/srt/model_executor/model_runner.py        +19    -136
python/sglang/srt/model_loader/__init__.py              +34    -0
python/sglang/srt/model_loader/loader.py                +1139  -0
python/sglang/srt/model_loader/utils.py                 +41    -0
python/sglang/srt/model_loader/weight_utils.py          +640   -0
python/sglang/srt/models/baichuan.py                    +3     -5
python/sglang/srt/models/chatglm.py                     +5     -14
python/sglang/srt/models/commandr.py                    +1     -2
python/sglang/srt/models/dbrx.py                        +1     -2
python/sglang/bench_one_batch.py

@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
     model_config = ModelConfig(
         server_args.model_path,
         trust_remote_code=server_args.trust_remote_code,
+        revision=server_args.revision,
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
+        is_embedding=server_args.is_embedding,
+        dtype=server_args.dtype,
+        quantization=server_args.quantization,
     )
     model_runner = ModelRunner(
         model_config=model_config,
python/sglang/srt/configs/device_config.py (new file, 0 → 100644)

import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


class DeviceConfig:
    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        if device in ["cuda", "xpu", "hpu"]:
            self.device_type = device
        else:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device = torch.device(self.device_type)
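
For reference, a minimal usage sketch of the new DeviceConfig (it assumes the sglang package above is importable; constructing the torch.device does not itself require a GPU):

from sglang.srt.configs.device_config import DeviceConfig

# "cuda", "xpu" and "hpu" are the accepted device types.
dev = DeviceConfig("cuda")
print(dev.device_type)   # cuda
print(dev.device)        # torch.device of type "cuda"

try:
    DeviceConfig("cpu")  # not in the accepted list, so __init__ raises
except RuntimeError as err:
    print(err)           # Not supported device type: cpu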
python/sglang/srt/configs/load_config.py (new file, 0 → 100644)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py

import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union

from sglang.srt.utils import is_hip

logger = logging.getLogger(__name__)


class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"


@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(model_loader_extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
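
A brief sketch of how this dataclass normalizes its inputs in __post_init__ (illustrative only; it assumes the module above is importable):

from sglang.srt.configs.load_config import LoadConfig, LoadFormat

cfg = LoadConfig(load_format="safetensors")
assert cfg.load_format is LoadFormat.SAFETENSORS  # string is normalized to the enum
assert cfg.ignore_patterns == ["original/**/*"]   # default ignore pattern is applied

# model_loader_extra_config may also be passed as a JSON string; it is parsed here.
cfg2 = LoadConfig(
    load_format=LoadFormat.SHARDED_STATE,
    model_loader_extra_config='{"pattern": "model-rank-{rank}-part-{part}.safetensors"}',
)
assert cfg2.model_loader_extra_config["pattern"].startswith("model-rank-")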
python/sglang/srt/configs/model_config.py

@@ -15,12 +15,14 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional
+from typing import List, Optional, Union

+import torch
 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.utils import get_bool_env_var, is_hip

 logger = logging.getLogger(__name__)

@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
 class ModelConfig:
     def __init__(
         self,
-        path: str,
+        model_path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
         is_embedding: Optional[bool] = None,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
     ) -> None:
+        self.model_path = model_path
+        self.revision = revision
+        self.quantization = quantization
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            path,
+            model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,

@@ -56,6 +63,7 @@ class ModelConfig:
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)

@@ -116,6 +124,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

+        self._verify_quantization()
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""

@@ -174,6 +184,86 @@ class ModelConfig:
         # parallel size so each GPU has at least one KV head.
         return max(1, total_num_kv_heads // tensor_parallel_size)

+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
+        return quant_cfg
+
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = [
+            "awq",
+            "gptq",
+            "fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "fbgemm_fp8",
+        ]
+        optimized_quantization_methods = [
+            "fp8",
+            "marlin",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "experts_int8",
+        ]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = self._parse_quant_hf_config()
+
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Detect which checkpoint is it
+            for _, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization
+                )
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization})."
+                )
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}."
+                )
+            if is_hip() and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in ROCm."
+                )
+            if self.quantization not in optimized_quantization_methods:
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.",
+                    self.quantization,
+                )
+

 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.

@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
     if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
         # We support non-hf version of llava models, so we do not want to
         # read the wrong values from the unused default text_config.
+        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
+        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
+        setattr(config, "torch_dtype", torch.float16)
         return config

     if hasattr(config, "text_config"):

@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
     return config


+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16."
+                    )
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
+

 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architectue
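
The managers below construct this config from server arguments; as a standalone illustration, a hedged sketch of the new call shape (the model path is a placeholder, and building the config fetches that model's Hugging Face config):

from sglang.srt.configs.model_config import ModelConfig

model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    trust_remote_code=True,
    revision=None,
    context_length=None,
    model_override_args="{}",            # JSON string, parsed with json.loads
    is_embedding=False,
    dtype="auto",                        # resolved via _get_and_verify_dtype
    quantization=None,                   # validated via _verify_quantization
)
print(model_config.dtype)                # a torch.dtype, e.g. torch.bfloat16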
python/sglang/srt/configs/qwen2vl.py

@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_scaling = rope_scaling

-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        # if self.rope_scaling["type"] == "mrope":
-        # self.rope_scaling["type"] = "default"
-        # self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
+        # NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
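
A minimal dict-only sketch of the normalization the new code performs (the values are illustrative; no transformers dependency is needed to see the effect):

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}  # illustrative values

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling)
# {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}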
python/sglang/srt/hf_transformers_utils.py

@@ -75,6 +75,8 @@ def get_config(
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
+        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
+        setattr(config, "_name_or_path", model)
     if model_override_args:
         config.update(model_override_args)
python/sglang/srt/layers/linear.py

@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "Fp8LinearMethod",
     "MarlinLinearMethod",
     "GPTQLinearMethod",
+    "QQQLinearMethod",
 ]
python/sglang/srt/lora/lora.py

@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader

 from sglang.srt.layers.linear import (
     ColumnParallelLinear,

@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader


 class BaseLayerWithLoRA(nn.Module):
python/sglang/srt/managers/scheduler.py

@@ -147,9 +147,12 @@ class Scheduler:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation
python/sglang/srt/managers/tokenizer_manager.py

@@ -109,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation
python/sglang/srt/managers/tp_worker.py

@@ -52,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
python/sglang/srt/model_executor/model_runner.py

@@ -14,22 +14,12 @@
 """ModelRunner runs the forward passes of the models."""

 import gc
-import importlib
-import importlib.resources
-import inspect
 import json
 import logging
-import pkgutil
-import time
-from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional

 import torch
 import torch.distributed as dist
-import torch.nn as nn
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,

@@ -37,9 +27,9 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry

+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend

@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
     is_hip,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )

@@ -228,49 +217,6 @@ class ModelRunner:
         return min_per_gpu_memory

-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-        if "task" in sig.parameters:
-            params["task"] = ""
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"

@@ -284,6 +230,7 @@ class ModelRunner:
                     "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                 )
                 self.server_args.dtype = "float16"
+                self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
                     raise RuntimeError("SGLang only supports sm75 and above.")

@@ -292,23 +239,21 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )
-        self.model = self.setup_model()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )

         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.vllm_model_config.dtype
+        self.dtype = self.model_config.dtype

         logger.info(
             f"Load weight end. "

@@ -319,12 +264,12 @@ class ModelRunner:
     def update_weights_from_disk(self, model_path: str, load_format: str):
         """Update engine weights online from disk."""
-        from vllm.model_executor.model_loader.loader import (
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+        from sglang.srt.model_loader.utils import set_default_torch_dtype

         logger.info(
             f"Update engine weights online from disk begin. "

@@ -332,15 +277,7 @@ class ModelRunner:
         )
         target_device = torch.device(self.device)
+        self.model_config.model_path = model_path

-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now

@@ -352,7 +289,7 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.model,
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True

@@ -370,9 +307,9 @@ class ModelRunner:
                     quant_method.process_weights_after_loading(module)
             return model

-        with set_default_torch_dtype(vllm_model_config.dtype):
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(vllm_model_config)
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message

@@ -384,16 +321,14 @@ class ModelRunner:
                 )
                 del iter
                 gc.collect()
-                iter = get_weight_iter(self.vllm_model_config)
+                iter = get_weight_iter(self.model_config)
                 self.model = model_load_weights(self.model, iter)
                 return False, message

         self.model = model
         self.server_args.model_path = model_path
         self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
         self.load_config = load_config
-        self.model_config.path = model_path

         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

@@ -794,55 +729,3 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(entry, list):
-                    # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
python/sglang/srt/model_loader/__init__.py (new file, 0 → 100644)

# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py

from torch import nn

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
from sglang.srt.model_loader.utils import (
    get_architecture_class_name,
    get_model_architecture,
)


def get_model(
    *,
    model_config: ModelConfig,
    load_config: LoadConfig,
    device_config: DeviceConfig,
) -> nn.Module:
    loader = get_model_loader(load_config)
    return loader.load_model(
        model_config=model_config,
        device_config=device_config,
    )


__all__ = [
    "get_model",
    "get_model_loader",
    "BaseModelLoader",
    "get_architecture_class_name",
    "get_model_architecture",
]
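
This mirrors what ModelRunner.load_model now does; as a hedged standalone sketch (the model path is a placeholder, and it assumes a CUDA device plus access to the checkpoint):

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model

model = get_model(
    model_config=ModelConfig(
        "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
        model_override_args="{}",
    ),
    load_config=LoadConfig(load_format="auto"),
    device_config=DeviceConfig("cuda"),
)
print(type(model))  # the architecture class resolved by get_model_architecture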
python/sglang/srt/model_loader/loader.py
0 → 100644
View file @
85e1a6f3
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
# ruff: noqa: SIM117
import
collections
import
dataclasses
import
fnmatch
import
glob
import
json
import
logging
import
math
import
os
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Type
,
cast
import
gguf
import
huggingface_hub
import
numpy
as
np
import
torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.load_config
import
LoadConfig
,
LoadFormat
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.model_loader.utils
import
(
get_model_architecture
,
set_default_torch_dtype
,
)
from
sglang.srt.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_gguf_extra_tensor_names
,
get_quant_config
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
,
)
from
sglang.srt.utils
import
(
get_device_capability
,
is_pin_memory_available
,
set_weight_attrs
,
)
@
contextmanager
def
device_loading_context
(
module
:
torch
.
nn
.
Module
,
target_device
:
torch
.
device
):
if
target_device
.
type
==
"cpu"
:
# If target is CPU, no need to move anything
yield
module
return
original_device_states
:
Dict
[
str
,
torch
.
device
]
=
{}
# Store original device states and move parameters to GPU if they're on CPU
for
name
,
p
in
module
.
named_parameters
():
if
p
.
device
.
type
==
"cpu"
:
original_device_states
[
name
]
=
p
.
device
p
.
data
=
p
.
data
.
to
(
target_device
)
# Parameters already on target device are not touched
try
:
yield
module
finally
:
# Restore parameters to their original devices, ignoring new parameters
pin_memory
=
is_pin_memory_available
()
for
name
,
p
in
module
.
named_parameters
():
if
name
in
original_device_states
:
original_device
:
torch
.
device
=
original_device_states
[
name
]
if
original_device
.
type
==
"cpu"
:
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
else
:
p
.
data
=
p
.
data
.
to
(
original_device
)
# New parameters or parameters already on target device are untouched
logger
=
logging
.
getLogger
(
__name__
)
def
_get_quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
major
,
minor
=
get_device_capability
()
if
major
is
not
None
and
minor
is
not
None
:
assert
0
<=
minor
<
10
capability
=
major
*
10
+
minor
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
"
"is not supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
if
model_config
.
dtype
not
in
supported_dtypes
:
raise
ValueError
(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
return
quant_config
return
None
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
,
_
=
get_model_architecture
(
model_config
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
quant_config
=
quant_config
,
)
class
BaseModelLoader
(
ABC
):
"""Base class for model loaders."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
@
abstractmethod
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download a model so that it can be immediately loaded."""
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
raise
NotImplementedError
class
DefaultModelLoader
(
BaseModelLoader
):
"""Model loader that can load different file types from disk."""
@
dataclasses
.
dataclass
class
Source
:
"""A source for weights."""
model_or_path
:
str
"""The model ID or path."""
revision
:
Optional
[
str
]
"""The optional model revision."""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
fall_back_to_pt
:
bool
=
True
"""Whether .pt weights can be used."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
_maybe_download_from_modelscope
(
self
,
model
:
str
,
revision
:
Optional
[
str
]
)
->
Optional
[
str
]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if
"SGLANG_USE_MODELSCOPE"
in
os
.
environ
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
if
not
os
.
path
.
exists
(
model
):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
ignore_file_pattern
=
self
.
load_config
.
ignore_patterns
,
)
else
:
model_path
=
model
return
model_path
return
None
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path
=
(
self
.
_maybe_download_from_modelscope
(
model_name_or_path
,
revision
)
or
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
allow_patterns
=
[
"*.bin"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
use_safetensors
:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
,
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
self
,
source
:
"Source"
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
source
.
model_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
,
)
elif
use_safetensors
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
# Apply the prefix.
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
def
_get_all_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
primary_weights
=
DefaultModelLoader
.
Source
(
model_config
.
model_path
,
model_config
.
revision
,
prefix
=
""
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
),
)
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
secondary_weights
=
cast
(
Iterable
[
DefaultModelLoader
.
Source
],
getattr
(
model
,
"secondary_weights"
,
())
)
for
source
in
secondary_weights
:
yield
from
self
.
_get_weights_iterator
(
source
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model_path
,
model_config
.
revision
,
fall_back_to_pt
=
True
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
)
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
return
model
.
eval
()
class
DummyModelLoader
(
BaseModelLoader
):
"""Model loader that will set model weights to random values."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
return
model
.
eval
()
class
ShardedStateLoader
(
BaseModelLoader
):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_state.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
extra_config
=
(
{}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
()
)
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
if
extra_config
:
raise
ValueError
(
f
"Unexpected extra config keys for load format "
f
"
{
load_config
.
load_format
}
: "
f
"
{
load_config
.
model_loader_extra_config
.
keys
()
}
"
)
@
staticmethod
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
])
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups
:
Dict
[
Any
,
List
[
Tuple
[
str
,
torch
.
Tensor
]]]
=
(
collections
.
defaultdict
(
list
)
)
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
same_storage_groups
[
tensor
.
device
,
ptr
].
append
((
key
,
tensor
))
def
get_end_ptr
(
tensor
:
torch
.
Tensor
)
->
int
:
return
tensor
.
view
(
-
1
)[
-
1
].
data_ptr
()
+
tensor
.
element_size
()
result
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
group
in
same_storage_groups
.
values
():
for
k
,
t
in
group
:
a
,
b
=
t
.
data_ptr
(),
get_end_ptr
(
t
)
for
k2
,
t2
in
group
:
if
not
t2
.
is_contiguous
():
continue
a2
,
b2
=
t2
.
data_ptr
(),
get_end_ptr
(
t2
)
if
a
<
a2
or
b2
<
b
:
continue
if
a2
<
a
or
b
<
b2
or
not
t
.
is_contiguous
():
break
# t2 covers strictly more memory than t.
if
k2
<
k
:
# Same tensors, keep the one with the smaller key.
break
else
:
result
[
k
]
=
t
return
result
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]):
if
os
.
path
.
isdir
(
model_name_or_path
):
return
model_name_or_path
else
:
allow_patterns
=
[
"*.safetensors"
]
return
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model_path
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
from
safetensors.torch
import
safe_open
from
vllm.distributed
import
get_tensor_model_parallel_rank
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model_path
,
model_config
.
revision
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
local_model_path
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
path
in
filepaths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
safetensors.torch
import
save_file
from
vllm.distributed
import
get_tensor_model_parallel_rank
if
pattern
is
None
:
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
rank
=
get_tensor_model_parallel_rank
()
part_idx
=
0
total_size
=
0
state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
model
.
state_dict
())
state_dict_part
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
tensor
in
state_dict
.
items
():
param_size
=
tensor
.
nelement
()
*
tensor
.
element_size
()
if
max_size
is
not
None
and
total_size
+
param_size
>
max_size
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
part_idx
+=
1
total_size
=
0
state_dict_part
=
{}
state_dict_part
[
key
]
=
tensor
total_size
+=
param_size
if
len
(
state_dict_part
)
>
0
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
"""Model loader to load model weights with BitAndBytes quantization."""
possible_config_file_names
=
[
"adapter_config.json"
]
default_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
".fc1."
,
".fc2."
,
".dense."
,
".query_key_value."
,
".qkv_proj."
,
".dense_h_to_4h."
,
".dense_4h_to_h."
,
".out_proj."
,
]
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
# we don't need to quantize the whole model, only the target modules
# that are specified in the adapter config file. If the adapter config
# file is not provided, we will quantize the default modules.
if
(
not
load_config
.
model_loader_extra_config
or
"qlora_adapter_name_or_path"
not
in
load_config
.
model_loader_extra_config
):
self
.
target_modules
=
[]
return
qlora_adapter
=
load_config
.
model_loader_extra_config
[
"qlora_adapter_name_or_path"
]
config_file_path
=
self
.
_get_config_file
(
qlora_adapter
)
with
open
(
config_file_path
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
self
.
target_modules
=
config
[
"target_modules"
]
def
_get_config_file
(
self
,
qlora_adapter
:
str
)
->
str
:
is_local
=
os
.
path
.
isdir
(
qlora_adapter
)
config_file_path
=
None
if
is_local
:
for
file
in
self
.
possible_config_file_names
:
config_file_path
=
os
.
path
.
join
(
qlora_adapter
,
file
)
if
os
.
path
.
exists
(
config_file_path
):
break
else
:
hf_api
=
HfApi
()
repo_files
=
hf_api
.
list_repo_files
(
repo_id
=
qlora_adapter
)
for
file
in
self
.
possible_config_file_names
:
if
file
in
repo_files
:
config_file_path
=
hf_hub_download
(
repo_id
=
qlora_adapter
,
filename
=
file
)
break
if
not
config_file_path
:
raise
ValueError
(
f
"Cannot find adapter config file in
{
qlora_adapter
}
"
)
return
config_file_path
    def _get_weight_files(
        self,
        model_name_or_path: str,
        allowed_patterns: List[str],
        revision: Optional[str] = None,
    ) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Return the weight files and the file pattern."""
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(
        self, model_name_or_path: str, revision: Optional[str]
    ) -> Tuple[List[str], bool]:
        """Prepare weight files for the model."""

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision
        )

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        else:
            return pt_weights_iterator(hf_weights_files)

    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary."""

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes

            if bitsandbytes.__version__ < "0.44.0":
                raise ImportError(
                    "bitsandbytes version is wrong. Please "
                    "install bitsandbytes>=0.44.0."
                )
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes>=0.44.0 via "
                "`pip install bitsandbytes>=0.44.0` to use "
                "bitsandbytes quantizer."
            ) from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision
        )

        quant_state_dict: Dict[str, Any] = {}

        if pre_quant:
            if load_8bit:
                return (
                    self._quantized_8bit_generator(
                        hf_weights_files, use_safetensors, quant_state_dict
                    ),
                    quant_state_dict,
                )
            else:
                return (
                    self._quantized_4bit_generator(
                        hf_weights_files, use_safetensors, quant_state_dict
                    ),
                    quant_state_dict,
                )

        return (
            self._unquantized_generator(
                hf_weights_files, use_safetensors, quant_state_dict
            ),
            quant_state_dict,
        )
    def _quantized_8bit_generator(
        self, hf_weights_files, use_safetensors, quant_state_dict
    ) -> Generator:
        for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors
        ):
            if not weight_name.lower().endswith(".scb"):
                continue

            weight_key = weight_name.lower().replace(".scb", ".qweight")
            quant_state_dict[weight_key] = weight_tensor

        for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors
        ):
            if not weight_name.endswith((".weight", ".bias")):
                continue

            qweight_name = weight_name.replace(".weight", ".qweight")

            if qweight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
                yield qweight_name, weight_tensor
            else:
                yield weight_name, weight_tensor

    def _quantized_4bit_generator(
        self, hf_weights_files, use_safetensors, quant_state_dict
    ) -> Generator:
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights
        weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
        temp_state_dict = {}
        for weight_name, weight_tensor in weight_iterator:
            if weight_name.endswith((".weight", ".bias")):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            if "quant_state.bitsandbytes" in weight_name:
                temp_state_dict[weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str, temp_state_dict: Dict) -> QuantState:
            quant_state = {}
            for k in temp_state_dict:
                if param_name + "." in k:
                    quant_state[k] = temp_state_dict[k]

            return QuantState.from_dict(quant_state, device="cuda")

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors
        ):
            if not weight_name.endswith((".weight", ".bias")):
                continue

            if (
                f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
            ) or (f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict):
                quant_state = _parse_quant_state(weight_name, temp_state_dict)
                weight_name = weight_name.replace(".weight", ".qweight")
                quant_state_dict[weight_name] = quant_state
                yield weight_name.replace(".weight", ".qweight"), weight_tensor
            else:
                yield weight_name, weight_tensor
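    # Illustrative only: a pre-quantized 4-bit checkpoint ships auxiliary
    # tensors keyed like
    #   "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4"
    # (the layer name here is a placeholder); _parse_quant_state folds them
    # back into a single QuantState attached to the matching ".qweight" entry.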
    def _unquantized_generator(
        self, hf_weights_files, use_safetensors, quant_state_dict
    ) -> Generator:
        from bitsandbytes.functional import quantize_4bit

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        for weight_name, weight_tensor in self._hf_weight_iter(
            hf_weights_files, use_safetensors
        ):
            if any(
                target_module in weight_name for target_module in self.target_modules
            ) and weight_name.endswith(".weight"):
                weight_name = weight_name.replace(".weight", ".qweight")

                if any(
                    module in weight_name
                    for module in self.column_parallel_weights_modules
                ):
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[..., start_index:end_index]
                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[start_index:end_index, ...]

                # bitsandbytes requires data in GPU
                if weight_sub_tensor.is_cuda:
                    loaded_weight = weight_sub_tensor
                else:
                    loaded_weight = weight_sub_tensor.cuda()

                # remove the following after the issue is fixed:
                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
                if loaded_weight.is_contiguous() is False:
                    loaded_weight = loaded_weight.contiguous()

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight, compress_statistics=True, quant_type="nf4"
                    )

                quant_state_dict[weight_name] = quant_state
            else:
                processed_weight = weight_tensor

            yield weight_name, processed_weight
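    # Worked example (illustrative): with tp_size=2 and a target weight of
    # shape (4096, 11008), the default branch slices dim 0, so rank 0 quantizes
    # rows [0:2048) and rank 1 rows [2048:4096); modules listed in
    # column_parallel_weights_modules slice the last dimension instead.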
    def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
        if not hasattr(model, "load_weights"):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}."
            )

        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet."
            )

        if len(self.target_modules) == 0:
            if hasattr(model, "default_bitsandbytes_target_modules"):
                self.target_modules = model.default_bitsandbytes_target_modules
            else:
                self.target_modules = self.default_target_modules

        if hasattr(model, "column_parallel_weights_modules"):
            self.column_parallel_weights_modules = model.column_parallel_weights_modules
        else:
            self.column_parallel_weights_modules = []

        self.model_type = type(model).__name__

        logger.info(
            "Loading weights with BitsAndBytes quantization. " " May take a while ..."
        )

        quant_config = getattr(model_config.hf_config, "quantization_config", None)

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get("quant_method")
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} quantization"
                )

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported."
                "Please try with PP."
            )

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get("load_in_8bit", False)

        qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
            model_config.model_path, model_config.revision, pre_quant, load_8bit
        )

        model.load_weights(qweight_iterator)

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for quant_param_name in quant_state_dict:
            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                weight_name,
                index,
            ) in model.bitsandbytes_stacked_params_mapping.items():
                if shard_name in quant_param_name:
                    shard_index = index
                    quant_param_name = quant_param_name.replace(shard_name, weight_name)
                    break

            if quant_param_name not in param_dict:
                raise ValueError(
                    f"Parameter {quant_param_name} not found in the model."
                )

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
                non_stacked_param_name
            ]

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = math.prod(quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

                if load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)}
                    )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model_path, model_config.revision)

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                )

                self._load_weights(model_config, model)

        return model.eval()
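# Illustrative only: a model's bitsandbytes_stacked_params_mapping maps shard
# names to their fused parameter and shard index, e.g.
#   {"q_proj": ("qkv_proj", 0), "k_proj": ("qkv_proj", 1), "v_proj": ("qkv_proj", 2)}
# so the per-shard QuantStates collected above can be attached to the fused
# parameter as bnb_quant_state together with the bnb_shard_offsets computed
# from pack_factor.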
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def _prepare_weights(self, model_name_or_path: str):
        if os.path.isfile(model_name_or_path):
            return model_name_or_path
        else:
            raise ValueError(f"{model_name_or_path} is not a file.")

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
        config = model_config.hf_config
        model_type = config.model_type
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        num_layers = config.num_hidden_layers
        name_map = gguf.get_tensor_name_map(arch, num_layers)
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(config)
        state_dict = dummy_model.state_dict()

        gguf_to_hf_name_map = {}
        for hf_name in state_dict:
            name, suffix = hf_name.rsplit(".", 1)
            gguf_name = name_map.get_name(name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map
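    # Illustrative only: for a llama-style checkpoint the resulting map pairs
    # GGUF tensor names with their HF counterparts, roughly
    #   "blk.0.attn_q.weight" -> "model.layers.0.self_attn.q_proj.weight"
    #   "output_norm.weight"  -> "model.norm.weight"
    # (exact names depend on the architecture registered in gguf).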
    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model_path)

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        local_model_path = self._prepare_weights(model_config.model_path)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
            local_model_path, gguf_weights_map
        ):
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)
            model.load_weights(
                self._get_weights_iterator(local_model_path, gguf_weights_map)
            )
        return model
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

    if load_config.load_format == LoadFormat.GGUF:
        return GGUFModelLoader(load_config)

    return DefaultModelLoader(load_config)
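For orientation, a minimal sketch of how a loader gets picked and used (illustrative, not part of this commit; the checkpoint path is a placeholder, and it assumes LoadFormat.AUTO and these config constructors accept the arguments shown):

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader

load_config = LoadConfig(load_format=LoadFormat.AUTO)  # falls through to DefaultModelLoader
model_config = ModelConfig("/path/to/hf/checkpoint")   # placeholder path
loader = get_model_loader(load_config)
model = loader.load_model(
    model_config=model_config, device_config=DeviceConfig("cuda")
)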
python/sglang/srt/model_loader/utils.py
0 → 100644
View file @
85e1a6f3
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/utils.py
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type

import torch
from torch import nn

from sglang.srt.configs.model_config import ModelConfig


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    from sglang.srt.models.registry import ModelRegistry

    architectures = getattr(model_config.hf_config, "architectures", [])
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"]

    if (
        model_config.quantization is not None
        and model_config.quantization not in mixtral_supported
        and "MixtralForCausalLM" in architectures
    ):
        architectures = ["QuantMixtralForCausalLM"]

    return ModelRegistry.resolve_model_cls(architectures)


def get_architecture_class_name(model_config: ModelConfig) -> str:
    return get_model_architecture(model_config)[1]
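A quick sketch of the context manager above (illustrative, not part of the commit): tensors created inside the `with` block default to the requested dtype, and the previous default is restored afterwards.

import torch
from sglang.srt.model_loader.utils import set_default_torch_dtype

with set_default_torch_dtype(torch.bfloat16):
    w = torch.empty(4, 4)  # bfloat16 while the context is active
# Outside the block the process-wide default (normally float32) is back in effect.
assert torch.get_default_dtype() == torch.float32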
python/sglang/srt/model_loader/weight_utils.py
0 → 100644
View file @
85e1a6f3
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import fnmatch
import glob
import hashlib
import json
import logging
import os
import tempfile
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.distributed import get_tensor_model_parallel_rank

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.utils import print_warning_once

logger = logging.getLogger(__name__)

# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()


def enable_hf_transfer():
    """automatically activates hf_transfer"""
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa

            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass


enable_hf_transfer()


class DisabledTqdm(tqdm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    lock_dir = cache_dir or temp_dir
    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
    return lock
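# Illustrative usage (not part of this commit): the lock serializes concurrent
# downloads of the same repo across processes sharing a cache directory, e.g.
#   with get_lock("org/model-name", cache_dir="/path/to/cache"):
#       ...  # download or convert weights here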
def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(
            f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """
        )

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
    quant_cls = get_quantization_config(model_config.quantization)

    # GGUF doesn't have config file
    if model_config.quantization == "gguf":
        return quant_cls.from_config({})

    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    # some vision model may keep quantization_config in their text_config
    hf_text_config = getattr(model_config.hf_config, "text_config", None)
    if hf_quant_config is None and hf_text_config is not None:
        hf_quant_config = getattr(hf_text_config, "quantization_config", None)
    if hf_quant_config is None:
        # compressed-tensors uses a compressions_config
        hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)
    # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
    if model_config.quantization == "bitsandbytes":
        if (
            not load_config.model_loader_extra_config
            or "qlora_adapter_name_or_path"
            not in load_config.model_loader_extra_config
        ):
            return quant_cls.from_config({"adapter_name_or_path": ""})
        model_name_or_path = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"
        ]
    else:
        model_name_or_path = model_config.model_path
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization config is not found, use the default config.
    if not possible_config_filenames:
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_config_files = [
        f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}"
        )

    quant_config_file = quant_config_files[0]
    with open(quant_config_file) as f:
        config = json.load(f)

        if model_config.quantization == "bitsandbytes":
            config["adapter_name_or_path"] = model_name_or_path
        elif model_config.quantization == "modelopt":
            if config["producer"]["name"] == "modelopt":
                return quant_cls.from_config(config)
            else:
                raise ValueError(
                    f"Unsupported quantization config"
                    f" found for {model_config.quantization} in {f}."
                )

    return quant_cls.from_config(config)
def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
    ignore_patterns: Optional[Union[str, List[str]]] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (List[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
            filter out the weight files. Files matched by any of the patterns
            will be ignored.

    Returns:
        str: The path to the downloaded model weights.
    """
    if not huggingface_hub.constants.HF_HUB_OFFLINE:
        # Before we download we look at that is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # depending on what is available we download different things
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

    logger.info("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        hf_folder = snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
    return hf_folder
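# Illustrative usage (not part of this commit; repo id and patterns are
# placeholders): restrict a snapshot to safetensors shards, e.g.
#   folder = download_weights_from_hf(
#       "org/model-name", cache_dir=None,
#       allow_patterns=["*.safetensors", "*.bin"],
#   )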
def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=index_file,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have index_file.
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", index_file)
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", index_file)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(
    hf_weights_files: List[str], hf_folder: str, index_file: str
) -> List[str]:
    # model.safetensors.index.json is a mapping from keys in the
    # torch state_dict to safetensors file holding that weight.
    index_file_name = os.path.join(hf_folder, index_file)
    if not os.path.isfile(index_file_name):
        return hf_weights_files

    # Iterate through the weight_map (weight_name: safetensors files)
    # to identify weights that we should use.
    with open(index_file_name) as f:
        weight_map = json.load(f)["weight_map"]
    weight_files_in_index = set()
    for weight_name in weight_map:
        weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
    # Filter out any fields that are not found in the index file.
    hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
    return hf_weights_files


def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
    """
    Exclude files that are not needed for inference.
    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


def np_cache_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str],
    hf_folder: str,
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model np files.

    Will dump the model weights to numpy files if they are not already dumped.
    """
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
    # Convert the model weights from torch tensors to numpy arrays for
    # faster loading.
    np_folder = os.path.join(hf_folder, "np")
    os.makedirs(np_folder, exist_ok=True)
    weight_names_file = os.path.join(np_folder, "weight_names.json")
    # Use file lock to prevent multiple processes from
    # dumping the same model weights to numpy at the same time.
    with get_lock(model_name_or_path, cache_dir):
        if not os.path.exists(weight_names_file):
            weight_names: List[str] = []
            for bin_file in tqdm(
                hf_weights_files,
                desc="Loading np_cache checkpoint shards",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
            ):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in state.items():
                    param_path = os.path.join(np_folder, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())
                    weight_names.append(name)
            with open(weight_names_file, "w") as f:
                json.dump(weight_names, f)

    with open(weight_names_file) as f:
        weight_names = json.load(f)

    for name in weight_names:
        param_path = os.path.join(np_folder, name)
        with open(param_path, "rb") as f:
            param = np.load(f)
        yield name, torch.from_numpy(param)
def safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
    for st_file in tqdm(
        hf_weights_files,
        desc="Loading safetensors checkpoint shards",
        disable=not enable_tqdm,
        bar_format=_BAR_FORMAT,
    ):
        with safe_open(st_file, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                param = f.get_tensor(name)
                yield name, param


def pt_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
    for bin_file in tqdm(
        hf_weights_files,
        desc="Loading pt checkpoint shards",
        disable=not enable_tqdm,
        bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location="cpu")
        yield from state.items()
        del state
        torch.cuda.empty_cache()
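# Illustrative usage (not part of this commit): both iterators yield
# (name, tensor) pairs, so a model's load_weights() can consume either, e.g.
#   files = sorted(glob.glob(os.path.join(folder, "*.safetensors")))
#   for name, tensor in safetensors_weights_iterator(files):
#       ...  # route each tensor to the matching parameter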
def get_gguf_extra_tensor_names(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]:
    reader = gguf.GGUFReader(gguf_file)
    expected_gguf_keys = set(gguf_to_hf_name_map.keys())
    exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
    extra_keys = expected_gguf_keys - exact_gguf_keys
    return [gguf_to_hf_name_map[key] for key in extra_keys]


def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """
    Iterate over the quant weights in the model gguf files and convert
    them to torch tensors
    """

    reader = gguf.GGUFReader(gguf_file)

    for tensor in reader.tensors:
        if tensor.name in gguf_to_hf_name_map:
            weight_type = tensor.tensor_type
            name = gguf_to_hf_name_map[tensor.name]

            if weight_type.name != "F32":
                weight_type_name = name.replace("weight", "qweight_type")
                weight_type = torch.tensor(weight_type)
                yield weight_type_name, weight_type

    for tensor in reader.tensors:
        if tensor.name in gguf_to_hf_name_map:
            weight = tensor.data
            weight_type = tensor.tensor_type
            name = gguf_to_hf_name_map[tensor.name]

            if weight_type.name != "F32":
                name = name.replace("weight", "qweight")
            param = torch.tensor(weight)
            yield name, param
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """convert PySafeSlice object from safetensors to torch.Tensor

    PySafeSlice object supports indexing, which is done before loading the
    actual tensor and can reduce the amount of memory being read into the
    memory. However, it does not support more advanced functionalities
    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
    tensor with these more complicated operators, we need to convert to
    tensor first.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})"
            )
            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise


def row_parallel_weight_loader(
    param: torch.Tensor, loaded_weight: torch.Tensor
) -> None:
    """Load weights that are row-parallelized."""
    tp_rank = get_tensor_model_parallel_rank()
    shard_dim = 0 if param.dim() != 1 else None

    if shard_dim is not None:
        shard_size = param.data.shape[shard_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)

    return default_weight_loader(param, loaded_weight)


LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
    """Create a weight loader that shards the weights along the given axis"""

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        tp_rank = get_tensor_model_parallel_rank()

        shard_size = param.data.shape[shard_axis]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)

        return default_weight_loader(param, loaded_weight)

    return loader


def composed_weight_loader(
    loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor]
) -> LoaderFunction:
    """Create a weight loader that post-processes the weights after loading"""

    def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        loader(param, loaded_weight)
        param.data.copy_(fn(param))
        return

    return composed_loader
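# Illustrative usage (not part of this commit): shard along dim 0 and then
# post-process the loaded value in one loader, e.g.
#   loader = composed_weight_loader(
#       sharded_weight_loader(0), lambda t: t.to(torch.float32)
#   )
#   loader(param, loaded_weight)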
def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    seed: int = 1234,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    We use per-parameter random seed, so that dummy weights are consistent,
    even if the model is partitioned across multiple devices. When the seed
    is fixed, the random values generated by this function only depends on
    the parameter's number of elements and its data type.
    """
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            generator = torch.Generator(device=param.data.device)
            generator.manual_seed(seed)
            if torch.finfo(param.data.dtype).bits < 16:
                # uniform_ doesn't support < 16-bit datatypes (FP8)
                dtype = param.data.dtype
                tmp_param = param.data.to(torch.float16)
                tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
                param.data.copy_(tmp_param)
            else:
                param.uniform_(low, high, generator=generator)


def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
            if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        print_warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale"
        )
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            print_warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded."
            )
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                print_warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
                    "not loaded."
                )
                return None
            return remapped_name

    # If there were no matches, return the untouched param name
    return name
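A short example of the remapping above (illustrative, parameter names are placeholders): a deprecated checkpoint entry `model.layers.0.self_attn.kv_scale` is rewritten to `model.layers.0.self_attn.attn.k_scale` when that name exists in the model's parameter dict; otherwise the scale is skipped with a warning.

remapped = maybe_remap_kv_scale_name("model.layers.0.self_attn.kv_scale", params_dict)
# -> "model.layers.0.self_attn.attn.k_scale" if present in params_dict, else None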
python/sglang/srt/models/baichuan.py
View file @
85e1a6f3
...
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
...
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
...
@@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", cache_config, quant_config)
+            super().__init__(config, "ROPE", quant_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", cache_config, quant_config)
+            super().__init__(config, "ALIBI", quant_config)
 
 EntryClass = [BaichuanForCausalLM]
python/sglang/srt/models/chatglm.py
View file @
85e1a6f3
...
@@ -23,7 +23,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig
 from sglang.srt.layers.activation import SiluAndMul
...
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 LoraConfig = None
...
@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
...
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
...
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
         )
         # Self attention.
-        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+        self.self_attention = GLMAttention(config, layer_id, quant_config)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
...
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
...
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):
         # Transformer layers.
         self.layers = nn.ModuleList(
-            [
-                GLMBlock(config, i, cache_config, quant_config)
-                for i in range(self.num_layers)
-            ]
+            [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
         )
 
         if self.post_layer_norm:
...
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
...
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, cache_config, quant_config)
+        self.encoder = GLMTransformer(config, quant_config)
         self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
...
@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoraConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMM(config, cache_config, quant_config)
+        self.transformer = ChatGLMM(config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)
...
python/sglang/srt/models/commandr.py
View file @
85e1a6f3
...
@@ -49,7 +49,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
...
@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import set_weight_attrs
...
@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
python/sglang/srt/models/dbrx.py
View file @
85e1a6f3
...
@@ -25,7 +25,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.fused_moe_triton import fused_moe
...
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import set_weight_attrs
...
@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
...