sglang · Commit 63e97e5e (unverified)
Authored Jan 23, 2024 by Arcmoon; committed by GitHub on Jan 22, 2024.
Support qwen model and solve some problems (#75)
Parent: e08bca28

Showing 7 changed files with 274 additions and 4 deletions (+274, −4).
README.md                                          +1   −0
python/sglang/srt/layers/radix_attention.py        +0   −1
python/sglang/srt/managers/detokenizer_manager.py  +1   −0
python/sglang/srt/managers/router/model_runner.py  +4   −0
python/sglang/srt/model_config.py                  +3   −1
python/sglang/srt/models/qwen.py                   +261 −0
python/sglang/srt/utils.py                         +4   −2
README.md

```diff
@@ -316,6 +316,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mixtral
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+- Qwen
 - AWQ quantization
 
 ## Benchmark And Performance
```
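By analogy with the LLaVA entry above, a Qwen checkpoint would be served with the same launcher, e.g. `python3 -m sglang.launch_server --model-path Qwen/Qwen-7B-Chat --port 30000` (the model path here is illustrative; the commit adds only the list entry, not a launch command).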
python/sglang/srt/layers/radix_attention.py

```diff
@@ -61,7 +61,6 @@ class RadixAttention(nn.Module):
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous(),
```
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -55,6 +55,7 @@ class DetokenizerManager:
             first_token = self.tokenizer.convert_ids_to_tokens(
                 int(output_tokens[i][0])
             )
+            first_token = first_token.decode("utf-8")
             if first_token.startswith("▁"):
                 output_strs[i] = " " + output_strs[i]
```
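Why the new line: Qwen's tokenizer is tiktoken-based, and its `convert_ids_to_tokens` returns raw bytes rather than the `str` that SentencePiece tokenizers return, so the token has to be decoded before the `"▁"` whitespace check. A minimal sketch of the distinction (the byte value is illustrative, and the `isinstance` guard is for the sketch only; the commit decodes unconditionally):

```python
# Illustrative: a Qwen-style tokenizer hands back UTF-8 bytes.
token = b"\xe4\xbd\xa0"  # UTF-8 encoding of "你" (illustrative value)
if isinstance(token, bytes):  # guard added for this sketch only
    token = token.decode("utf-8")
assert token == "你"
```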
python/sglang/srt/managers/router/model_runner.py

```diff
@@ -240,6 +240,7 @@ class ModelRunner:
         from sglang.srt.models.llama2 import LlamaForCausalLM
         from sglang.srt.models.llava import LlavaLlamaForCausalLM
         from sglang.srt.models.mixtral import MixtralForCausalLM
+        from sglang.srt.models.qwen import QWenLMHeadModel
 
         # Select model class
         architectures = getattr(self.model_config.hf_config, "architectures", [])
@@ -258,6 +259,9 @@ class ModelRunner:
             if arch == "MixtralForCausalLM":
                 model_class = MixtralForCausalLM
                 break
+            if arch == "QWenLMHeadModel":
+                model_class = QWenLMHeadModel
+                break
 
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
```
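For context, the `architectures` list this dispatch matches against comes from the checkpoint's HF `config.json`. For a Qwen-1 model it would look roughly like the excerpt below (illustrative values, not part of the commit):

```python
# Illustrative slice of a Qwen-7B-style config.json, as a Python dict.
# Only "architectures" drives the model-class selection above.
qwen_hf_config = {
    "architectures": ["QWenLMHeadModel"],  # matched by the new branch
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
}
```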
python/sglang/srt/model_config.py

```diff
@@ -20,8 +20,10 @@ class ModelConfig:
         # Unify the config keys for hf_config
         self.context_len = get_context_length(self.hf_config)
         self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
-        self.num_key_value_heads = self.hf_config.num_key_value_heads
         self.num_attention_heads = self.hf_config.num_attention_heads
+        self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
         self.vocab_size = self.hf_config.vocab_size
```
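The motivation appears to be that Qwen-1 uses plain multi-head attention, so its config carries no `num_key_value_heads` field and the old unconditional attribute read raised `AttributeError`. A minimal sketch of the fallback (hypothetical config object):

```python
from types import SimpleNamespace

# Hypothetical Qwen-like hf_config: MHA model, no num_key_value_heads field.
hf_config = SimpleNamespace(num_attention_heads=32)

num_key_value_heads = getattr(hf_config, "num_key_value_heads", None)
if num_key_value_heads is None:
    # MHA: every attention head has its own KV head.
    num_key_value_heads = hf_config.num_attention_heads
assert num_key_value_heads == 32
```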
python/sglang/srt/models/qwen.py (new file, mode 100644)

```python
from typing import Any, Dict, List, Optional, Tuple

import torch
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from vllm.transformers_utils.configs.qwen import QWenConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)


class QWenMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            2 * [intermediate_size],
            bias=False,
            gather_output=False,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads

        # pylint: disable=invalid-name
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    def __init__(self, config: QWenConfig, layer_id):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            layer_id=layer_id,
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class QWenModel(nn.Module):
    def __init__(self, config: QWenConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(
            vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList(
            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    def __init__(self, config: QWenConfig, linear_method=None):
        super().__init__()
        self.config = config
        self.transformer = QWenModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ):
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        next_tokens = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        return next_tokens

    _column_parallel_weights = []
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
```
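Two constants in the new file are easy to miss: `QWenMLP` receives `config.intermediate_size // 2` because Qwen's HF config reports the combined width of the gate and up projections, and both `QWenModel` and `QWenLMHeadModel` pad the vocabulary up to a multiple of 64 for the parallel embedding and LM head. A quick arithmetic check with Qwen-7B-like values (illustrative, not taken from this diff):

```python
# Illustrative Qwen-7B-style values (assumed, not from this commit).
intermediate_size = 22016
per_branch_ffn = intermediate_size // 2       # 11008 per gate/up branch
vocab_size = 151936
padded_vocab = ((vocab_size + 63) // 64) * 64  # round up to a multiple of 64
print(per_branch_ffn, padded_vocab)            # 11008 151936 (already a multiple of 64)
```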
python/sglang/srt/utils.py

```diff
@@ -108,9 +108,11 @@ def get_exception_traceback():
 def get_int_token_logit_bias(tokenizer, vocab_size):
     from transformers import LlamaTokenizer, LlamaTokenizerFast
 
+    # a bug when model's vocab size > tokenizer.vocab_size
+    vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
-        ss = tokenizer.decode(t_id).strip()
+        ss = tokenizer.decode([t_id]).strip()
         if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
             logit_bias[t_id] = -1e5
         # else:
@@ -214,4 +216,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image
\ No newline at end of file
```
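Both fixes in the first hunk look Qwen-motivated, though the commit message does not say so: a model's (padded) vocabulary can exceed `tokenizer.vocab_size`, so iterating with the model vocab would decode out-of-range ids, and `decode` is now passed a list because not every tokenizer accepts a bare int. A self-contained toy sketch of the fixed loop (all names and values hypothetical):

```python
# Toy tokenizer standing in for one that, like Qwen's (presumably),
# only accepts a sequence of ids in decode().
class ToyTokenizer:
    vocab = {0: "<eos>", 1: "7", 2: "hello"}
    eos_token_id = 0
    vocab_size = 3

    def decode(self, token_ids):
        return "".join(self.vocab[t] for t in token_ids)

tok = ToyTokenizer()
banned = [
    t_id
    for t_id in range(tok.vocab_size)  # tokenizer vocab, not model vocab
    if not (
        tok.decode([t_id]).strip().isdigit()   # list form, as in the fix
        or len(tok.decode([t_id]).strip()) == 0
        or t_id == tok.eos_token_id
    )
]
print(banned)  # [2] — only non-digit, non-empty, non-EOS tokens get penalized
```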