Unverified commit a8c787d2, authored Jun 12, 2024 by Qubitium; committed by GitHub Jun 11, 2024

Add ChatGLM Model Support (#516)

Co-authored-by: ZX <zx@lbx.dev>

Parent: 5f283991
Showing 3 changed files with 468 additions and 3 deletions (+468 -3)
python/sglang/srt/managers/controller/model_runner.py  +10 -1
python/sglang/srt/model_config.py  +68 -2
python/sglang/srt/models/chatglm.py  +390 -0
python/sglang/srt/managers/controller/model_runner.py
@@ -330,7 +330,7 @@ class ModelRunner:
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
             dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
@@ -446,11 +446,20 @@ def import_model_classes():
                     model_arch_name_to_cls[tmp.__name__] = tmp
             else:
                 model_arch_name_to_cls[entry.__name__] = entry

+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
     return model_arch_name_to_cls


 def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     model_arch_name_to_cls = import_model_classes()
     if model_arch not in model_arch_name_to_cls:
         raise ValueError(
             f"Unsupported architectures: {model_arch}. "
...
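For context on the new remapping hook: a model module can now export, alongside its EntryClass, an EntryClassRemapping list that maps the architecture string found in a checkpoint's config.json onto the real entry class. The sketch below is illustrative and not part of the commit; it assumes the chatglm.py module added later in this diff is importable, and the import path simply mirrors the file path shown above.

# Illustrative sketch (not part of this commit): once sglang/srt/models/chatglm.py
# (added below) is importable, a checkpoint whose config.json lists "ChatGLMModel"
# resolves to ChatGLMForCausalLM through the EntryClassRemapping table.
from sglang.srt.managers.controller.model_runner import load_model_cls_srt

model_cls = load_model_cls_srt("ChatGLMModel")
assert model_cls.__name__ == "ChatGLMForCausalLM"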
python/sglang/srt/model_config.py
 from typing import Optional

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from transformers import PretrainedConfig


 class ModelConfig:
@@ -18,7 +19,7 @@ class ModelConfig:
         self.model_overide_args = model_overide_args
         self.hf_config = get_config(self.path, trust_remote_code, revision,
                                     model_overide_args=model_overide_args)
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -43,4 +44,69 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
\ No newline at end of file
+        self.vocab_size = self.hf_config.vocab_size
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False))
+        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
+                                                   "multi_query", False):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            return getattr(self.hf_config.attn_config, "kv_n_heads",
+                           self.hf_config.num_attention_heads)
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
+    def get_num_kv_heads(self, tensor_parallel_size) -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1, total_num_kv_heads // tensor_parallel_size)
+
+
+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+    else:
+        return config
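As a sanity check on the two helpers just added, here is a standalone sketch of the same lookup-then-divide behaviour. The SimpleNamespace stand-in and the head counts (32 attention heads, 2 KV groups) are illustrative assumptions in the spirit of a ChatGLM-style config, not values read from sglang's ModelConfig or any real checkpoint.

# Standalone sketch of the KV-head logic above (illustrative, not sglang code).
from types import SimpleNamespace


def num_kv_heads_per_gpu(text_config, tensor_parallel_size: int) -> int:
    # Same attribute search order as ModelConfig.get_total_num_kv_heads().
    for attr in ("n_head_kv", "num_kv_heads", "num_key_value_heads",
                 "multi_query_group_num"):
        total = getattr(text_config, attr, None)
        if total is not None:
            break
    else:
        # Non-grouped-query models: KV heads == attention heads.
        total = text_config.num_attention_heads
    # Divide across tensor-parallel ranks, replicating when there are fewer
    # KV heads than ranks so every GPU keeps at least one.
    return max(1, total // tensor_parallel_size)


# ChatGLM-style config with grouped KV heads (values are illustrative).
cfg = SimpleNamespace(num_attention_heads=32, multi_query_group_num=2)
assert num_kv_heads_per_gpu(cfg, 1) == 2
assert num_kv_heads_per_gpu(cfg, 2) == 1
assert num_kv_heads_per_gpu(cfg, 4) == 1  # replicated, not split further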
python/sglang/srt/models/chatglm.py (new file, 0 → 100644)
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from peft import LoraConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitsProcessor
from torch import nn
from torch.nn import LayerNorm
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
class GLMAttention(nn.Module):

    def __init__(
        self,
        config,
        layer_id: int = 0,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            input_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output
class GLMMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output
class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config,
        layer_id: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, layer_id, cache_config,
                                           quant_config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            input_metadata=input_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output
class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList([
            GLMBlock(config, i, cache_config, quant_config)
            for i in range(self.num_layers)
        ])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                input_metadata=input_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states
class ChatGLMModel(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            input_metadata=input_metadata,
        )
        return hidden_states
class ChatGLMForCausalLM(nn.Module):
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoraConfig] = None,
    ):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        return self.logits_processor(input_ids, hidden_states,
                                     self.lm_head.weight, input_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)


EntryClass = ChatGLMForCausalLM
# compat: glm model.config class == ChatGLMModel
EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
\ No newline at end of file
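One detail of load_weights worth calling out: the ".word_embeddings" rename is what lets checkpoint tensor names line up with this module's parameter tree, where the vocabulary embedding lives directly at transformer.embedding. The key below is an assumed example of how THUDM checkpoints name that tensor; the snippet is illustrative and not part of the commit.

# Illustrative only: how the rename inside load_weights() maps an assumed
# checkpoint key onto ChatGLMModel's parameter name.
checkpoint_key = "transformer.embedding.word_embeddings.weight"  # assumed example key
if "word_embeddings" in checkpoint_key:
    checkpoint_key = checkpoint_key.replace(".word_embeddings", "")
assert checkpoint_key == "transformer.embedding.weight"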