change / sglang · Commits · 19d2135c

Commit 19d2135c (unverified), authored May 21, 2024 by Lianmin Zheng, committed via GitHub on May 21, 2024
Parent: ced77c66

    Use model loader from vllm (#459)

Showing 20 changed files with 152 additions and 978 deletions (+152 −978)
    examples/quick_start/srt_example_yi_vl.py           +2    -0
    python/sglang/srt/managers/router/model_rpc.py      +1    -0
    python/sglang/srt/managers/router/model_runner.py   +57   -81
    python/sglang/srt/model_config.py                   +4    -5
    python/sglang/srt/models/commandr.py                +7    -12
    python/sglang/srt/models/dbrx.py                    +6    -14
    python/sglang/srt/models/dbrx_config.py             +0    -281
    python/sglang/srt/models/gemma.py                   +5    -13
    python/sglang/srt/models/llama2.py                  +9    -13
    python/sglang/srt/models/llava.py                   +6    -15
    python/sglang/srt/models/llava_mistral.py           +6    -15
    python/sglang/srt/models/llava_qwen.py              +6    -15
    python/sglang/srt/models/llavavid.py                +7    -18
    python/sglang/srt/models/mixtral.py                 +7    -18
    python/sglang/srt/models/qwen.py                    +6    -12
    python/sglang/srt/models/qwen2.py                   +4    -12
    python/sglang/srt/models/stablelm.py                +6    -14
    python/sglang/srt/models/yivl.py                    +12   -22
    python/sglang/srt/weight_utils.py                   +0    -417
    python/sglang/utils.py                              +1    -1
examples/quick_start/srt_example_yi_vl.py · view file @ 19d2135c

 """
 Usage: python3 srt_example_yi_vl.py
+Requirements: transformers==4.38
 """

 import sglang as sgl
...
python/sglang/srt/managers/router/model_rpc.py · view file @ 19d2135c

...
@@ -41,6 +41,7 @@ from sglang.utils import get_exception_traceback

 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
+logging.getLogger("vllm.selector").setLevel(logging.WARN)


 class ModelRpcServer:
...
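The single added line follows the pattern already used for "vllm.utils" two lines up: grab the library's named child logger and raise its threshold so its INFO chatter (here, presumably the attention-backend selector in the newer vllm this commit targets) stays out of sglang's server logs. A minimal, self-contained illustration of the pattern:

    import logging

    # Quiet one library logger without touching anything else.
    logging.getLogger("vllm.selector").setLevel(logging.WARN)

    logging.getLogger("vllm.selector").info("backend chosen")  # suppressed
    logging.getLogger("vllm.selector").warning("still shown")  # passes the threshold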
python/sglang/srt/managers/router/model_runner.py · view file @ 19d2135c

 import importlib
 import importlib.resources
+import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type

 import numpy as np
 import torch
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import initialize_model_parallel
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}
-
 logger = logging.getLogger("model_runner")
...
@@ -31,35 +26,6 @@ logger = logging.getLogger("model_runner")
 global_server_args_dict = {}


-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"
...
@@ -287,49 +253,32 @@ class ModelRunner:
         self.is_multimodal_model = is_multimodal_model(self.model_config)

     def load_model(self):
-        """See also vllm/model_executor/model_loader.py::get_model"""
-        # Select model class
-        architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = get_model_cls_by_arch_name(architectures)
         logger.info(f"Rank {self.tp_rank}: load weight begin.")

-        # Load weights
-        quant_config = None
-        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
-        if quant_cfg is not None:
-            quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = quant_cfg.get(
-                "checkpoint_format"
-            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                quant_method = "marlin"
-            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
-            if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {quant_method}")
-            quant_config = quant_config_class.from_config(quant_cfg)
-            logger.info(f"quant_config: {quant_config}")
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-            model.load_weights(
-                self.model_config.path,
-                cache_dir=None,
-                load_format=self.load_format,
-                revision=None,
-            )
-        self.model = model.eval()
+        device_config = DeviceConfig()
+        load_config = LoadConfig()
+        vllm_model_config = VllmModelConfig(
+            model=self.model_config.path,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.model_config.trust_remote_code,
+            dtype=torch.float16,
+            seed=42,
+            revision=self.model_config.revision,
+            skip_tokenizer_init=True,
+        )
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+        )

         logger.info(f"Rank {self.tp_rank}: load weight end.")

     def profile_max_num_token(self, total_gpu_memory):
...
@@ -455,3 +404,30 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
\ No newline at end of file
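Taken together, these model_runner changes invert the loading flow: instead of sglang picking the model class and letting each model open checkpoint files itself, vllm's get_model() now drives loading, and the setattr() monkey patch makes vllm resolve architecture names through sglang's own EntryClass table. A rough sketch of the resulting contract, using simplified stand-ins rather than vllm's actual internals:

    from typing import Iterable, Iterator, Tuple

    import torch
    import torch.nn as nn


    def iterate_checkpoint_weights(path: str) -> Iterator[Tuple[str, torch.Tensor]]:
        # Stand-in for the loader's checkpoint iterator; the real one streams
        # safetensors/pt shards according to LoadConfig.
        yield from torch.load(path).items()


    def get_model_sketch(model_cls, hf_config, checkpoint_path: str) -> nn.Module:
        # 1. vllm resolves model_cls via ModelRegistry.load_model_cls(arch);
        #    after the monkey patch above, that lookup hits sglang's table.
        model = model_cls(config=hf_config)
        # 2. The loader feeds (name, tensor) pairs into the model, which is why
        #    every load_weights() below now takes Iterable[Tuple[str, Tensor]].
        model.load_weights(iterate_checkpoint_weights(checkpoint_path))
        return model.eval()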
python/sglang/srt/model_config.py · view file @ 19d2135c

...
@@ -15,10 +15,9 @@ class ModelConfig:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.hf_config = get_config(self.path, trust_remote_code, revision,
-            model_overide_args=model_overide_args
-        )
+        self.model_overide_args = model_overide_args
+        self.hf_config = get_config(self.path, trust_remote_code, revision)
+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)

         if context_length is not None:
             self.context_len = context_length
...
@@ -44,4 +43,4 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
+        self.vocab_size = self.hf_config.vocab_size
\ No newline at end of file
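The override path above relies on transformers' PretrainedConfig.update(), which copies each dict entry onto the config as an attribute; load_model() in model_runner.py applies the same dict to the vllm-side config. A small illustration (the config class and override key are just examples):

    from transformers import LlamaConfig

    config = LlamaConfig()                                    # stand-in for get_config(...)
    model_overide_args = {"max_position_embeddings": 16384}   # hypothetical override
    config.update(model_overide_args)                         # sets attributes from the dict
    assert config.max_position_embeddings == 16384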
python/sglang/srt/models/commandr.py · view file @ 19d2135c

...
@@ -18,9 +18,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1

 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable

 import torch
 import torch.utils.checkpoint
...
@@ -41,11 +44,11 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 @torch.compile
...
@@ -324,13 +327,7 @@ class CohereForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
...
@@ -341,9 +338,7 @@ class CohereForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
...
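The body of load_weights() is untouched here and in every model file below; only the source of (name, tensor) pairs changes. For reference, a condensed sketch of the shared pattern (simplified: the real implementations also handle expert weights, bias shards, and per-model skip rules):

    from typing import Iterable, Tuple

    import torch
    import torch.nn as nn


    def load_weights_sketch(model: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id): q/k/v checkpoint tensors are
            # fused into one qkv_proj parameter at runtime.
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(model.named_parameters())
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                param = params_dict[name.replace(shard_name, param_name)]
                # weight_loader is attached to the parameter by vllm's parallel
                # linear layers; it writes this shard into its slice of the
                # fused tensor.
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Unfused parameters fall back to a plain copy, which is what
                # default_weight_loader does.
                params_dict[name].data.copy_(loaded_weight)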
python/sglang/srt/models/dbrx.py · view file @ 19d2135c

 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
 # coding=utf-8
-from typing import Optional
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
...
@@ -24,12 +24,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.models.dbrx_config import DbrxConfig
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class DbrxRouter(nn.Module):
...
@@ -377,13 +377,7 @@ class DbrxForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
             (
                 "ws" if weight_name in ["w1", "v1"] else "w2s",
...
@@ -392,9 +386,7 @@ class DbrxForCausalLM(nn.Module):
             for weight_name in ["w1", "v1", "w2"]
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
...
python/sglang/srt/models/dbrx_config.py · deleted (100644 → 0) · view file @ ced77c66

The entire file is removed by this commit. Its former content:

# Adapted from:
# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/transformers_utils/configs/dbrx.py
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
"""Dbrx configuration."""

# FIXME: remove this once vllm releases a new version
from typing import Any, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
    the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.

    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.

    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
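The vendored copy carried its own expiry note ("FIXME: remove this once vllm releases a new version"); the vllm revision this commit pins against ships the same config, so dbrx.py above now imports it directly:

    # Replaces sglang.srt.models.dbrx_config in dbrx.py
    from vllm.transformers_utils.configs.dbrx import DbrxConfig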
python/sglang/srt/models/gemma.py · view file @ 19d2135c

 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
...
@@ -18,11 +18,11 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class GemmaMLP(nn.Module):
...
@@ -285,13 +285,7 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
...
@@ -302,9 +296,7 @@ class GemmaForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
...
python/sglang/srt/models/llama2.py · view file @ 19d2135c

 # Adapted from
-# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Iterable

 import torch
 from torch import nn
...
@@ -20,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlamaMLP(nn.Module):
...
@@ -152,6 +152,10 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(config, "original_max_position_embeddings", None):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
...
@@ -270,13 +274,7 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
...
@@ -286,9 +284,7 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
...
python/sglang/srt/models/llava.py · view file @ 19d2135c

 """Inference-only LLaVa model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import List, Iterable, Optional, Tuple

 import numpy as np
 import torch
...
@@ -8,6 +8,7 @@ from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
...
@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlavaLlamaForCausalLM(nn.Module):
...
@@ -233,13 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(input_ids, positions, input_metadata)

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
...
@@ -272,9 +266,8 @@ class LlavaLlamaForCausalLM(nn.Module):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
...
@@ -285,9 +278,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)

         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)

         monkey_path_clip_vision_embed_forward()
...
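One detail worth noting in the llava diff (repeated in llava_mistral.py, llava_qwen.py, and llavavid.py below): the loader now hands over a one-shot iterator, but these wrappers need two passes over the checkpoint, one for the projector/vision tower and one for the wrapped language model, hence weights = list(weights). The shape of it, with the method body abridged:

    from typing import Iterable, Tuple

    import torch


    class AbridgedLlavaLoader:
        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            weights = list(weights)  # a generator would be exhausted after pass 1
            for name, loaded_weight in weights:  # pass 1: vision side
                if "projector" in name or "vision_tower" in name:
                    pass  # written into params_dict in the real code above
            # pass 2: the buffered list can be iterated again
            self.language_model.load_weights(weights)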
python/sglang/srt/models/llava_mistral.py · view file @ 19d2135c

 """Inference-only LLaVa model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import List, Iterable, Optional, Tuple

 import numpy as np
 import torch
...
@@ -8,6 +8,7 @@ from torch import nn
 from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
...
@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.mistral import MistralForCausalLM
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlavaMistralForCausalLM(nn.Module):
...
@@ -246,13 +246,7 @@ class LlavaMistralForCausalLM(nn.Module):
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(input_ids, positions, input_metadata)

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
...
@@ -285,9 +279,8 @@ class LlavaMistralForCausalLM(nn.Module):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
...
@@ -298,9 +291,7 @@ class LlavaMistralForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)

         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)

         monkey_path_clip_vision_embed_forward()
...
python/sglang/srt/models/llava_qwen.py · view file @ 19d2135c

 """Inference-only LLaVa model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import List, Iterable, Optional, Tuple

 import numpy as np
 import torch
...
@@ -8,6 +8,7 @@ from torch import nn
 from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
...
@@ -17,7 +18,6 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlavaQwenForCausalLM(nn.Module):
...
@@ -246,13 +246,7 @@ class LlavaQwenForCausalLM(nn.Module):
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(input_ids, positions, input_metadata)

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
...
@@ -285,9 +279,8 @@ class LlavaQwenForCausalLM(nn.Module):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
...
@@ -298,9 +291,7 @@ class LlavaQwenForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)

         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)

         monkey_path_clip_vision_embed_forward()
...
python/sglang/srt/models/llavavid.py · view file @ 19d2135c

 """Inference-only LLaVa video model compatible with HuggingFace weights."""
-import os
-from typing import List, Optional
+from typing import List, Iterable, Optional, Tuple

 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
+from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
...
@@ -18,7 +18,6 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class LlavaVidForCausalLM(nn.Module):
...
@@ -65,7 +64,6 @@ class LlavaVidForCausalLM(nn.Module):
         pad_ids = pad_value * (
             (new_image_feature_len + len(pad_value)) // len(pad_value)
         )
-        # print(input_ids)
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
         new_input_ids = (
...
@@ -200,13 +198,7 @@ class LlavaVidForCausalLM(nn.Module):
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(input_ids, positions, input_metadata)

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
...
@@ -244,9 +236,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
...
@@ -261,9 +252,7 @@ class LlavaVidForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)

         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)

         monkey_path_clip_vision_embed_forward()
...
python/sglang/srt/models/mixtral.py · view file @ 19d2135c

 # Adapted from
-# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
 """Inference-only Mixtral model."""
-from typing import Optional
+from typing import Iterable, Optional, Tuple

 import numpy as np
 import torch
...
@@ -25,11 +25,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class MixtralMLP(nn.Module):
...
@@ -107,7 +108,7 @@ class MixtralMoE(nn.Module):
             ]
         )
         self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
+            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...
@@ -333,13 +334,7 @@ class MixtralForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
...
@@ -348,13 +343,7 @@ class MixtralForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path,
-            cache_dir,
-            load_format,
-            revision,
-            fall_back_to_pt=False,
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
...
python/sglang/srt/models/qwen.py
View file @
19d2135c
-from typing import Any, Dict, Optional
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+from typing import Any, Dict, Optional, Iterable, Tuple

 import torch
 from torch import nn
 ...
@@ -17,11 +19,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class QWenMLP(nn.Module):
 ...
@@ -245,22 +247,14 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
 ...
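QWen's mapping entries differ from the attention case: they fuse the MLP's two input projections, stored as w2 and w1 in the HF checkpoint, into a single gate_up_proj parameter, with the shard id selecting which half of the fused output dimension each checkpoint tensor fills. A toy illustration of that semantics, ignoring tensor parallelism; every name and size here is local to the example:

    import torch

    hidden, intermediate = 8, 16
    # Fused parameter as the model creates it: two projections stacked on dim 0.
    gate_up = torch.empty(2 * intermediate, hidden)

    # Checkpoint tensors, as they would arrive from the weights iterator.
    w2 = torch.randn(intermediate, hidden)  # shard_id 0
    w1 = torch.randn(intermediate, hidden)  # shard_id 1

    for shard_id, loaded in ((0, w2), (1, w1)):
        # Each shard occupies a contiguous slice of the fused output dimension.
        gate_up.narrow(0, shard_id * intermediate, intermediate).copy_(loaded)

    assert torch.equal(gate_up[:intermediate], w2)
    assert torch.equal(gate_up[intermediate:], w1)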
python/sglang/srt/models/qwen2.py
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Iterable

 import torch
 from torch import nn
 ...
@@ -19,11 +19,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator

 Qwen2Config = None
 ...
@@ -271,13 +271,7 @@ class Qwen2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
 ...
@@ -287,9 +281,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
 ...
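After this change nothing in Qwen2ForCausalLM knows about paths, revisions, or load formats; any iterator of (name, tensor) pairs will do. A hedged sketch of driving the new signature by hand from local safetensors shards: `model` and the checkpoint folder are assumptions for illustration, while safe_open is the standard safetensors API.

    import glob

    from safetensors.torch import safe_open

    def local_safetensors_iterator(folder: str):
        """Yield (name, tensor) pairs from every *.safetensors shard in a folder."""
        for shard in sorted(glob.glob(f"{folder}/*.safetensors")):
            with safe_open(shard, framework="pt") as f:
                for name in f.keys():
                    yield name, f.get_tensor(name)

    # `model` is an already-constructed Qwen2ForCausalLM (assumption); the path
    # is a placeholder. In production, vllm's model loader builds this iterator.
    model.load_weights(local_safetensors_iterator("/path/to/qwen2-checkpoint"))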
python/sglang/srt/models/stablelm.py
-# This code is based on:
-# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable

 import torch
 from torch import nn
 ...
@@ -20,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


 class StablelmMLP(nn.Module):
 ...
@@ -245,13 +245,7 @@ class StableLmForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
 ...
@@ -261,9 +255,7 @@ class StableLmForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
 ...
python/sglang/srt/models/yivl.py
"""Inference-only Yi-VL model."""
"""Inference-only Yi-VL model."""
import
os
from
typing
import
Tuple
,
Iterable
from
typing
import
List
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
CLIPVisionModel
,
LlavaConfig
from
transformers
import
CLIPVisionModel
,
LlavaConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.llava
import
(
from
sglang.srt.models.llava
import
(
LlavaLlamaForCausalLM
,
LlavaLlamaForCausalLM
,
clip_vision_embed_forward
,
monkey_path_clip_vision_embed_forward
,
monkey_path_clip_vision_embed_forward
,
)
)
from
sglang.srt.weight_utils
import
default_weight_loader
,
hf_model_weights_iterator
class
YiVLForCausalLM
(
LlavaLlamaForCausalLM
):
class
YiVLForCausalLM
(
LlavaLlamaForCausalLM
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
.
config
=
kwargs
[
"config"
]
self
,
config
,
quant_config
=
None
,
super
().
__init__
(
self
.
config
)
)
->
None
:
super
().
__init__
(
config
,
quant_config
)
self
.
multi_modal_projector
=
YiVLMultiModalProjector
(
self
.
config
)
self
.
multi_modal_projector
=
YiVLMultiModalProjector
(
self
.
config
)
self
.
vision_tower_subfolder
=
self
.
config
.
mm_vision_tower
.
replace
(
self
.
vision_tower_subfolder
=
self
.
config
.
mm_vision_tower
.
replace
(
"./"
,
""
"./"
,
""
)
# Everything after "./"
)
# Everything after "./"
def
load_weights
(
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self
.
vision_tower
=
CLIPVisionModel
.
from_pretrained
(
self
.
vision_tower
=
CLIPVisionModel
.
from_pretrained
(
model
_name_or_path
,
self
.
config
.
_name_or_path
,
torch_dtype
=
torch
.
float16
,
torch_dtype
=
torch
.
float16
,
subfolder
=
self
.
vision_tower_subfolder
,
subfolder
=
self
.
vision_tower_subfolder
,
).
cuda
()
).
cuda
()
...
@@ -68,9 +61,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
...
@@ -68,9 +61,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
"model.vision_tower.vision_tower"
:
"vision_tower"
,
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.vision_tower.vision_tower"
:
"vision_tower"
,
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
}
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
weights
=
list
(
weights
)
model_name_or_path
,
cache_dir
,
load_format
,
revision
for
name
,
loaded_weight
in
weights
:
):
if
"projector"
in
name
or
"vision_tower"
in
name
:
if
"projector"
in
name
or
"vision_tower"
in
name
:
for
weight_name
,
param_name
in
projector_weights
.
items
():
for
weight_name
,
param_name
in
projector_weights
.
items
():
if
weight_name
in
name
:
if
weight_name
in
name
:
...
@@ -80,9 +72,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
...
@@ -80,9 +72,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
# load language model
# load language model
self
.
language_model
.
load_weights
(
self
.
language_model
.
load_weights
(
weights
)
model_name_or_path
,
cache_dir
,
load_format
,
revision
)
monkey_path_clip_vision_embed_forward
()
monkey_path_clip_vision_embed_forward
()
...
@@ -103,7 +93,7 @@ class YiVLMultiModalProjector(nn.Module):
...
@@ -103,7 +93,7 @@ class YiVLMultiModalProjector(nn.Module):
def
forward
(
self
,
image_features
):
def
forward
(
self
,
image_features
):
hidden_states
=
self
.
linear_1
(
image_features
)
hidden_states
=
self
.
linear_1
(
image_features
)
hidden_state
=
self
.
ln_1
(
hidden_states
)
hidden_state
s
=
self
.
ln_1
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
hidden_states
=
self
.
ln_2
(
hidden_states
)
hidden_states
=
self
.
ln_2
(
hidden_states
)
...
...
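The `weights = list(weights)` line in the Yi-VL hunk is load-bearing: the loader hands over a one-shot iterator, and this model walks the weights twice, once for the projector/vision-tower pass and once inside `self.language_model.load_weights(weights)`. A stripped-down illustration of the failure that materializing the list avoids; the names and values are stand-ins:

    def fake_weights():
        # Stand-in for the loader's one-shot (name, tensor) generator.
        yield ("vision_tower.embed", 1)
        yield ("model.layers.0.attn", 2)

    gen = fake_weights()
    _ = [name for name, _ in gen]  # first pass consumes the generator
    assert list(gen) == []         # a second pass would see nothing

    ws = list(fake_weights())      # materialize once...
    assert len(ws) == 2 and len(list(ws)) == 2  # ...and both passes see everything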
python/sglang/srt/weight_utils.py deleted 100644 → 0
# The vllm PR (https://github.com/vllm-project/vllm/pull/4097) broke the sglang code.
# In order to adapt to the latest code without modifying too much code, we
# copied the previous vllm/model_executor/weight_utils.py from
# https://github.com/vllm-project/vllm/blob/05434764cd99990035779cf9a4ed86623b528825/vllm/model_executor/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import
fnmatch
import
glob
import
hashlib
import
json
import
os
from
collections
import
defaultdict
from
typing
import
Any
,
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
filelock
import
huggingface_hub.constants
import
numpy
as
np
import
torch
from
huggingface_hub
import
HfFileSystem
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
,
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
logger
=
init_logger
(
__name__
)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir
=
(
os
.
environ
.
get
(
"TMPDIR"
)
or
os
.
environ
.
get
(
"TEMP"
)
or
os
.
environ
.
get
(
"TMP"
)
or
"/tmp/"
)
def
enable_hf_transfer
():
"""automatically activates hf_transfer"""
if
"HF_HUB_ENABLE_HF_TRANSFER"
not
in
os
.
environ
:
try
:
# enable hf hub transfer if available
import
hf_transfer
# type: ignore # noqa
huggingface_hub
.
constants
.
HF_HUB_ENABLE_HF_TRANSFER
=
True
except
ImportError
:
pass
enable_hf_transfer
()
class
Disabledtqdm
(
tqdm
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
,
disable
=
True
)
def
get_lock
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
):
lock_dir
=
cache_dir
or
temp_dir
os
.
makedirs
(
os
.
path
.
dirname
(
lock_dir
),
exist_ok
=
True
)
model_name
=
model_name_or_path
.
replace
(
"/"
,
"-"
)
hash_name
=
hashlib
.
sha256
(
model_name
.
encode
()).
hexdigest
()
# add hash to avoid conflict with old users' lock files
lock_file_name
=
hash_name
+
model_name
+
".lock"
# mode 0o666 is required for the filelock to be shared across users
lock
=
filelock
.
FileLock
(
os
.
path
.
join
(
lock_dir
,
lock_file_name
),
mode
=
0o666
)
return
lock
def
_shared_pointers
(
tensors
):
ptrs
=
defaultdict
(
list
)
for
k
,
v
in
tensors
.
items
():
ptrs
[
v
.
data_ptr
()].
append
(
k
)
failing
=
[]
for
_
,
names
in
ptrs
.
items
():
if
len
(
names
)
>
1
:
failing
.
append
(
names
)
return
failing
def
convert_bin_to_safetensor_file
(
pt_filename
:
str
,
sf_filename
:
str
,
)
->
None
:
loaded
=
torch
.
load
(
pt_filename
,
map_location
=
"cpu"
)
if
"state_dict"
in
loaded
:
loaded
=
loaded
[
"state_dict"
]
shared
=
_shared_pointers
(
loaded
)
for
shared_weights
in
shared
:
for
name
in
shared_weights
[
1
:]:
loaded
.
pop
(
name
)
# For tensors to be contiguous
loaded
=
{
k
:
v
.
contiguous
()
for
k
,
v
in
loaded
.
items
()}
dirname
=
os
.
path
.
dirname
(
sf_filename
)
os
.
makedirs
(
dirname
,
exist_ok
=
True
)
save_file
(
loaded
,
sf_filename
,
metadata
=
{
"format"
:
"pt"
})
# check file size
sf_size
=
os
.
stat
(
sf_filename
).
st_size
pt_size
=
os
.
stat
(
pt_filename
).
st_size
if
(
sf_size
-
pt_size
)
/
pt_size
>
0.01
:
raise
RuntimeError
(
f
"""The file size different is more than 1%:
-
{
sf_filename
}
:
{
sf_size
}
-
{
pt_filename
}
:
{
pt_size
}
"""
)
# check if the tensors are the same
reloaded
=
load_file
(
sf_filename
)
for
k
in
loaded
:
pt_tensor
=
loaded
[
k
]
sf_tensor
=
reloaded
[
k
]
if
not
torch
.
equal
(
pt_tensor
,
sf_tensor
):
raise
RuntimeError
(
f
"The output tensors do not match for key
{
k
}
"
)
# TODO(woosuk): Move this to other place.
def
get_quant_config
(
model_config
:
ModelConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
is
not
None
:
return
quant_cls
.
from_config
(
hf_quant_config
)
model_name_or_path
=
model_config
.
model
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
if
not
is_local
:
# Download the config files.
with
get_lock
(
model_name_or_path
,
model_config
.
download_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
revision
=
model_config
.
revision
,
allow_patterns
=
"*.json"
,
cache_dir
=
model_config
.
download_dir
,
tqdm_class
=
Disabledtqdm
,
)
else
:
hf_folder
=
model_name_or_path
config_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
"*.json"
))
quant_config_files
=
[
f
for
f
in
config_files
if
any
(
f
.
endswith
(
x
)
for
x
in
quant_cls
.
get_config_filenames
())
]
if
len
(
quant_config_files
)
==
0
:
raise
ValueError
(
f
"Cannot find the config file for
{
model_config
.
quantization
}
"
)
if
len
(
quant_config_files
)
>
1
:
raise
ValueError
(
f
"Found multiple config files for
{
model_config
.
quantization
}
: "
f
"
{
quant_config_files
}
"
)
quant_config_file
=
quant_config_files
[
0
]
with
open
(
quant_config_file
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
return
quant_cls
.
from_config
(
config
)
def
prepare_hf_model_weights
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
fall_back_to_pt
:
bool
=
True
,
revision
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
# Download model weights from huggingface.
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
and
load_format
!=
"tensorizer"
use_safetensors
=
False
# Some quantized models use .pt files for storing the weights.
if
load_format
==
"auto"
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
"safetensors"
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
"pt"
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
"npcache"
:
allow_patterns
=
[
"*.bin"
]
elif
load_format
==
"tensorizer"
:
allow_patterns
=
[
"*.tensors"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
not
is_local
and
load_format
!=
"tensorizer"
:
# Before we download we look at that is available:
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
# depending on what is available we download different things
for
pattern
in
allow_patterns
:
matching
=
fnmatch
.
filter
(
file_list
,
pattern
)
if
len
(
matching
)
>
0
:
allow_patterns
=
[
pattern
]
break
logger
.
info
(
f
"Using model weights format
{
allow_patterns
}
"
)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with
get_lock
(
model_name_or_path
,
cache_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
allow_patterns
=
allow_patterns
,
cache_dir
=
cache_dir
,
tqdm_class
=
Disabledtqdm
,
revision
=
revision
,
)
else
:
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
not
use_safetensors
:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist
=
[
"training_args.bin"
,
"optimizer.bin"
,
"optimizer.pt"
,
"scheduler.pt"
,
"scaler.pt"
,
]
hf_weights_files
=
[
f
for
f
in
hf_weights_files
if
not
any
(
f
.
endswith
(
x
)
for
x
in
blacklist
)
]
if
load_format
==
"tensorizer"
:
return
hf_folder
,
hf_weights_files
,
use_safetensors
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
hf_model_weights_iterator
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
Union
[
Tuple
,
str
]
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
fall_back_to_pt
:
Optional
[
bool
]
=
True
,
)
->
Iterator
[
Tuple
[
str
,
torch
.
Tensor
]]:
hf_folder
,
hf_weights_files
,
use_safetensors
=
prepare_hf_model_weights
(
model_name_or_path
,
cache_dir
=
cache_dir
,
load_format
=
load_format
,
fall_back_to_pt
=
fall_back_to_pt
,
revision
=
revision
,
)
if
load_format
==
"npcache"
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder
=
os
.
path
.
join
(
hf_folder
,
"np"
)
os
.
makedirs
(
np_folder
,
exist_ok
=
True
)
weight_names_file
=
os
.
path
.
join
(
np_folder
,
"weight_names.json"
)
# Use file lock to prevent multiple processes from
# dumping the same model weights to numpy at the same time.
with
get_lock
(
model_name_or_path
,
cache_dir
):
if
not
os
.
path
.
exists
(
weight_names_file
):
weight_names
=
[]
for
bin_file
in
hf_weights_files
:
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
for
name
,
param
in
state
.
items
():
param_path
=
os
.
path
.
join
(
np_folder
,
name
)
with
open
(
param_path
,
"wb"
)
as
f
:
np
.
save
(
f
,
param
.
cpu
().
detach
().
numpy
())
weight_names
.
append
(
name
)
with
open
(
weight_names_file
,
"w"
)
as
f
:
json
.
dump
(
weight_names
,
f
)
with
open
(
weight_names_file
,
"r"
)
as
f
:
weight_names
=
json
.
load
(
f
)
for
name
in
weight_names
:
param_path
=
os
.
path
.
join
(
np_folder
,
name
)
with
open
(
param_path
,
"rb"
)
as
f
:
param
=
np
.
load
(
f
)
yield
name
,
torch
.
from_numpy
(
param
)
elif
load_format
==
"tensorizer"
:
from
vllm.model_executor.tensorizer_loader
import
(
TensorDeserializer
,
open_stream
,
tensorizer_warning
,
)
tensorizer_args
=
load_format
.
params
tensorizer_warning
(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models."
)
deserializer_args
=
tensorizer_args
.
deserializer_params
stream_params
=
tensorizer_args
.
stream_params
stream
=
open_stream
(
tensorizer_args
.
tensorizer_uri
,
**
stream_params
)
with
TensorDeserializer
(
stream
,
**
deserializer_args
,
device
=
"cpu"
)
as
state
:
for
name
,
param
in
state
.
items
():
yield
name
,
param
del
state
elif
use_safetensors
:
for
st_file
in
hf_weights_files
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
param
=
f
.
get_tensor
(
name
)
yield
name
,
param
else
:
for
bin_file
in
hf_weights_files
:
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
for
name
,
param
in
state
.
items
():
yield
name
,
param
del
state
torch
.
cuda
.
empty_cache
()
def
kv_cache_scales_loader
(
filename
:
str
,
tp_rank
:
int
,
tp_size
:
int
,
num_hidden_layers
:
int
,
model_type
:
Optional
[
str
],
)
->
Iterable
[
Tuple
[
int
,
float
]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of examples/fp8/extract_scales.py
"""
try
:
with
open
(
filename
)
as
f
:
context
=
{
"model_type"
:
model_type
,
"num_hidden_layers"
:
num_hidden_layers
,
"tp_rank"
:
tp_rank
,
"tp_size"
:
tp_size
,
}
schema_dct
=
json
.
load
(
f
)
schema
=
QuantParamSchema
.
model_validate
(
schema_dct
,
context
=
context
)
layer_scales_map
=
schema
.
kv_cache
.
scaling_factor
[
tp_rank
]
return
layer_scales_map
.
items
()
except
FileNotFoundError
:
logger
.
error
(
f
"File or directory '
{
filename
}
' not found."
)
except
json
.
JSONDecodeError
:
logger
.
error
(
f
"Error decoding JSON in file '
{
filename
}
'."
)
except
Exception
as
e
:
logger
.
error
(
f
"An error occurred while reading '
{
filename
}
':
{
e
}
"
)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger
.
warning
(
"Defaulting to KV cache scaling factors = 1.0 "
f
"for all layers in TP rank
{
tp_rank
}
"
"as an error occurred during loading."
)
return
[]
def
convert_pyslice_to_tensor
(
x
:
Any
)
->
torch
.
Tensor
:
"""convert PySafeSlice object from safetensors to torch.Tensor
PySafeSlice object supports indexing, which is done before loading the
actual tensor and can reduce the amount of memory being read into the
memory. However, it does not support more advanced functionalities
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
tensor with these more complicated operators, we need to convert to
tensor first.
"""
if
not
isinstance
(
x
,
torch
.
Tensor
):
x
=
x
[:]
return
x
def
default_weight_loader
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Default weight loader."""
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
)
def
initialize_dummy_weights
(
model
:
torch
.
nn
.
Module
,
low
:
float
=
-
1e-3
,
high
:
float
=
1e-3
,
)
->
None
:
"""Initialize model weights with random values.
The model weights must be randomly initialized for accurate performance
measurements. Additionally, the model weights should not cause NaNs in the
forward pass. We empirically found that initializing the weights with
values between -1e-3 and 1e-3 works well for most models.
"""
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
param
.
data
.
uniform_
(
low
,
high
)
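For reference, the deleted iterator was self-contained: it resolved a checkpoint locally or on the Hub, picked a weight format, and streamed tensors one by one. A typical call looked like the following; the model id is a placeholder for illustration, not something this repo pins.

    # Hypothetical usage of the deleted helper: resolve, download, and stream
    # a HF checkpoint as (name, tensor) pairs.
    for name, tensor in hf_model_weights_iterator(
        "stabilityai/stablelm-2-1_6b",  # placeholder model id
        cache_dir=None,
        load_format="auto",  # prefer *.safetensors, fall back to *.bin / *.pt
        revision=None,
    ):
        print(name, tuple(tensor.shape))

All of this responsibility (locking, download, format fallback) now lives in vllm's model loader, which is why the models above only keep the routing half of the logic.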
python/sglang/utils.py
 ...
@@ -141,7 +141,7 @@ def encode_frame(frame):


 def encode_video_base64(video_path, num_frames=16):
-    import cv2
+    import cv2  # pip install opencv-python-headless

     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
 ...
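The added comment documents an optional dependency rather than changing behavior. A minimal usage sketch of the helper; the video path is a placeholder:

    # Requires: pip install opencv-python-headless
    from sglang.utils import encode_video_base64

    # Sample 16 frames from a local clip and encode them as base64 for a prompt.
    video_b64 = encode_video_base64("demo.mp4", num_frames=16)  # placeholder path
    print(len(video_b64))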