norm / vllm · Commits

Commit 28e616c4 (unverified)
fix qwen-14b model (#1173)

Authored Sep 28, 2023 by Qing; committed by GitHub on Sep 27, 2023
Parent: 30e77528

Showing 2 changed files with 32 additions and 43 deletions:
  vllm/model_executor/models/qwen.py       +7  -7
  vllm/transformers_utils/configs/qwen.py  +25 -36
vllm/model_executor/models/qwen.py (view file @ 28e616c4)

@@ -141,17 +141,17 @@ class QWenBlock(nn.Module):
     def __init__(self, config: QWenConfig):
         super().__init__()
-        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

         rope_theta = getattr(config, "rope_theta", 10000)
-        self.attn = QWenAttention(config.n_embd,
+        self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta)

-        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

-        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)

     def forward(
         self,
@@ -190,11 +190,11 @@ class QWenModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.n_embd,
+                                          config.hidden_size,
                                           perform_initialization=False)
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

     def forward(
         self,
@@ -230,7 +230,7 @@ class QWenLMHeadModel(nn.Module):
         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.n_embd,
+        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             vocab_size,
                                             bias=False,
                                             gather_output=False,
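The unchanged ((config.vocab_size + 63) // 64) * 64 context lines in the QWenModel and QWenLMHeadModel hunks round the vocabulary up to a multiple of 64 before building the parallel embedding and output head. A minimal sketch of that arithmetic, using the old and new default vocab_size values from the config change below (the pad_vocab helper name is hypothetical, added only for illustration):

def pad_vocab(vocab_size: int) -> int:
    # Same expression as in the diff: round up to the next multiple of 64
    # for VocabParallelEmbedding / ColumnParallelLinear.
    return ((vocab_size + 63) // 64) * 64

print(pad_vocab(151851))  # 151872: old default, padded up by 21 rows
print(pad_vocab(151936))  # 151936: new default, already a multiple of 64

With the corrected default, the padding step becomes a no-op.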
vllm/transformers_utils/configs/qwen.py (view file @ 28e616c4)

@@ -7,65 +7,54 @@ from transformers import PretrainedConfig
 class QWenConfig(PretrainedConfig):
     model_type = "qwen"
     keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {
-        "hidden_size": "n_embd",
-        "num_attention_heads": "n_head",
-        "max_position_embeddings": "n_positions",
-        "num_hidden_layers": "n_layer",
-    }

     def __init__(
         self,
-        vocab_size=151851,
-        n_embd=4096,
-        n_layer=32,
-        n_head=32,
-        n_inner=None,
-        embd_pdrop=0.0,
-        attn_pdrop=0.0,
-        layer_norm_epsilon=1e-5,
+        vocab_size=151936,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        emb_dropout_prob=0.0,
+        attn_dropout_prob=0.0,
+        layer_norm_epsilon=1e-6,
         initializer_range=0.02,
+        max_position_embeddings=8192,
         scale_attn_weights=True,
         use_cache=True,
-        eos_token_id=151643,
-        apply_residual_connection_post_layernorm=False,
-        bf16=True,
+        bf16=False,
+        fp16=False,
+        fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
-        use_dynamic_ntk=False,
-        use_logn_attn=False,
-        use_flash_attn=True,
-        ffn_hidden_size=22016,
+        use_dynamic_ntk=True,
+        use_logn_attn=True,
+        use_flash_attn="auto",
+        intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
         **kwargs,
     ):
-        self.eos_token_id = eos_token_id
-        super().__init__(eos_token_id=eos_token_id,
-                         tie_word_embeddings=tie_word_embeddings,
-                         **kwargs)
-
         self.vocab_size = vocab_size
-        self.n_embd = n_embd
-        self.n_layer = n_layer
-        self.n_head = n_head
-        self.n_inner = n_inner
-        self.embd_pdrop = embd_pdrop
-        self.attn_pdrop = attn_pdrop
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.emb_dropout_prob = emb_dropout_prob
+        self.attn_dropout_prob = attn_dropout_prob
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
-        self.apply_residual_connection_post_layernorm = (
-            apply_residual_connection_post_layernorm)
+        self.max_position_embeddings = max_position_embeddings
         self.bf16 = bf16
+        self.fp16 = fp16
+        self.fp32 = fp32
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
-        self.ffn_hidden_size = ffn_hidden_size
         self.no_bias = no_bias
-        self.tie_word_embeddings = tie_word_embeddings
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
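A brief usage sketch of the updated config, assuming this revision of vllm (plus transformers) is installed; the import path is the file shown above, and the printed values are the new defaults from this diff:

from vllm.transformers_utils.configs.qwen import QWenConfig

config = QWenConfig()  # defaults now mirror Qwen's official configuration_qwen.py
# The model code above now reads these names directly instead of the old
# n_embd / ffn_hidden_size aliases:
print(config.hidden_size)             # 4096  -> RMSNorm, QWenAttention, lm_head
print(config.intermediate_size // 2)  # 11008 -> QWenMLP
print(config.num_hidden_layers)       # 32    -> number of QWenBlock layers
print(config.layer_norm_epsilon)      # 1e-06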