sglang · Commit 656aed58 (unverified)
Authored Jan 09, 2025 by Yunmeng; committed via GitHub on Jan 09, 2025.

Remove vllm dependency in model config (#2809)

Parent: b5fb4ef5

Showing 6 changed files with 372 additions and 16 deletions (+372 / -16).
python/sglang/srt/configs/__init__.py        +4    -0
python/sglang/srt/configs/chatglm.py         +78   -0
python/sglang/srt/configs/dbrx.py            +279  -0
python/sglang/srt/hf_transformers_utils.py   +9    -14
python/sglang/srt/models/chatglm.py          +1    -1
python/sglang/srt/models/dbrx.py             +1    -1
python/sglang/srt/configs/__init__.py

+from sglang.srt.configs.chatglm import ChatGLMConfig
+from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig
...
@@ -5,4 +7,6 @@ __all__ = [
     "ExaoneConfig",
     "Qwen2VLConfig",
     "Qwen2VLVisionConfig",
+    "ChatGLMConfig",
+    "DbrxConfig",
 ]
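
With the two new re-exports, every vendored config is importable from one place. A minimal sketch (not part of the commit) exercising the new import surface:

```python
# Illustrative only: all model configs now come from sglang.srt.configs.
from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig

# Each class exposes the model_type string used by the config registry further below.
for cfg_cls in (ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig):
    print(f"{cfg_cls.__name__}: {cfg_cls.model_type}")
```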
python/sglang/srt/configs/chatglm.py (new file, 0 → 100644)

# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py
# ChatGLM2 and ChatGLM3 share the same config.
# ChatGLM4 is officially supported by Huggingface
# transformers >= 4.46.0 is required
# https://huggingface.co/docs/transformers/en/model_doc/glm
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # It is to be compatible with long lora.
        self.max_position_embeddings = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
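
Because `ChatGLMConfig` subclasses `PretrainedConfig`, the `attribute_map` lets callers read Hugging Face-style names (`num_hidden_layers`, `n_head_kv`) while the ChatGLM-native fields hold the values. A short usage sketch, illustrative rather than part of the commit:

```python
from sglang.srt.configs import ChatGLMConfig

# Defaults mirror ChatGLM2-6B; override a few fields for illustration.
config = ChatGLMConfig(num_layers=28, multi_query_group_num=2, seq_length=8192)

assert config.num_hidden_layers == config.num_layers == 28          # via attribute_map
assert config.n_head_kv == config.multi_query_group_num == 2        # via attribute_map
assert config.max_position_embeddings == config.seq_length == 8192  # set in __init__
```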
python/sglang/srt/configs/dbrx.py (new file, 0 → 100644)

# Adapted from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py
"""Dbrx configuration."""

from typing import Any, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore


class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)


class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
    the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all "
                "configurations of models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)


class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.

    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.

    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
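
One detail worth noting in `DbrxConfig.__init__` is the promotion of plain-dict sub-configs (as they appear in a `config.json`) to the typed `DbrxAttentionConfig` / `DbrxFFNConfig` objects. A rough, illustrative sketch of that behaviour, not taken from the commit:

```python
from sglang.srt.configs.dbrx import DbrxConfig

# Dict sub-configs are promoted to typed config objects by __init__.
config = DbrxConfig(
    d_model=1024,
    n_heads=8,
    attn_config={"kv_n_heads": 2, "rope_theta": 500000.0},
    ffn_config={"moe_num_experts": 8, "moe_top_k": 2},
)

print(type(config.attn_config).__name__)  # DbrxAttentionConfig
print(config.attn_config.kv_n_heads)      # 2
print(config.ffn_config.moe_top_k)        # 2
print(config.max_position_embeddings)     # 2048, aliased to max_seq_len via attribute_map
```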
python/sglang/srt/hf_transformers_utils.py

...
@@ -30,20 +30,15 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
-try:
-    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
-
-    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
-
-    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-        ChatGLMConfig.model_type: ChatGLMConfig,
-        DbrxConfig.model_type: DbrxConfig,
-        ExaoneConfig.model_type: ExaoneConfig,
-        Qwen2VLConfig.model_type: Qwen2VLConfig,
-    }
-except ImportError:
-    # We want this file to run without vllm dependency
-    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
+from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig
+
+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+    ChatGLMConfig.model_type: ChatGLMConfig,
+    DbrxConfig.model_type: DbrxConfig,
+    ExaoneConfig.model_type: ExaoneConfig,
+    Qwen2VLConfig.model_type: Qwen2VLConfig,
+}
 
 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
...
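
The registry maps each `model_type` string to its config class, and the loop below it (truncated in this view) presumably registers those classes with Hugging Face inside `contextlib.suppress(ValueError)`, since re-registering an already-known model type raises `ValueError`. A hedged sketch of that pattern; the exact loop body is not shown here:

```python
import contextlib
from typing import Dict, Type

from transformers import AutoConfig, PretrainedConfig

from sglang.srt.configs import ChatGLMConfig, DbrxConfig

# Same shape as the registry in the diff above (subset shown here).
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ChatGLMConfig.model_type: ChatGLMConfig,
    DbrxConfig.model_type: DbrxConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
    # AutoConfig.register raises ValueError if the model type is already known
    # (e.g. newer transformers releases ship their own DbrxConfig), so the
    # error is suppressed rather than treated as fatal.
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)

# After registration, AutoConfig can build these configs from the type name.
print(AutoConfig.for_model("chatglm"))  # a ChatGLMConfig with default values
```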
python/sglang/srt/models/chatglm.py

...
@@ -23,8 +23,8 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs import ChatGLMConfig
+from sglang.srt.configs import ChatGLMConfig
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
...
python/sglang/srt/models/dbrx.py

...
@@ -25,8 +25,8 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs.dbrx import DbrxConfig
+from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
...
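
Both model files keep their vllm runtime imports (distributed utilities, rotary embeddings) and only swap where the config class comes from, so loading a model config no longer pulls in vllm. An illustrative check, assuming only that sglang is installed:

```python
# Illustrative only: the config classes used by the model code are now vendored
# in sglang, so they import and instantiate without vllm being present.
import importlib.util

from sglang.srt.configs import ChatGLMConfig, DbrxConfig

print("vllm available:", importlib.util.find_spec("vllm") is not None)
print(ChatGLMConfig().num_layers)  # 28, default from the vendored config
print(DbrxConfig().max_seq_len)    # 2048, default from the vendored config
```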