Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
de894728
Unverified
Commit
de894728
authored
Oct 13, 2023
by
Lu Wang
Committed by
GitHub
Oct 13, 2023
Browse files
Fix the issue for AquilaChat2-* models (#1339)
parent
e7c8555d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
2 deletions
+10
-2
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+1
-0
vllm/model_executor/models/aquila.py
vllm/model_executor/models/aquila.py
+3
-2
vllm/transformers_utils/configs/aquila.py
vllm/transformers_utils/configs/aquila.py
+6
-0
No files found.
vllm/model_executor/model_loader.py
View file @
de894728
...
...
@@ -14,6 +14,7 @@ from vllm.model_executor.weight_utils import (get_quant_config,
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY
=
{
"AquilaModel"
:
AquilaForCausalLM
,
"AquilaForCausalLM"
:
AquilaForCausalLM
,
# AquilaChat2
"BaiChuanForCausalLM"
:
BaiChuanForCausalLM
,
# baichuan-7b
"BaichuanForCausalLM"
:
BaichuanForCausalLM
,
# baichuan-13b
"BloomForCausalLM"
:
BloomForCausalLM
,
...
...
vllm/model_executor/models/aquila.py
View file @
de894728
...
...
@@ -147,6 +147,7 @@ class AquilaAttention(nn.Module):
rotary_dim
=
self
.
head_dim
,
base
=
self
.
rope_theta
,
max_position
=
self
.
max_position_embeddings
,
num_kv_heads
=
self
.
num_kv_heads
,
)
def
forward
(
...
...
@@ -177,7 +178,7 @@ class AquilaDecoderLayer(nn.Module):
self
.
self_attn
=
AquilaAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_
attention
_heads
,
num_kv_heads
=
config
.
num_
key_value
_heads
,
rope_theta
=
rope_theta
,
max_position_embeddings
=
max_position_embeddings
,
)
...
...
@@ -308,7 +309,7 @@ class AquilaForCausalLM(nn.Module):
q_proj_shard_size
=
(
self
.
config
.
hidden_size
//
tp_size
)
kv_proj_shard_size
=
(
self
.
config
.
hidden_size
//
self
.
config
.
num_attention_heads
*
self
.
config
.
num_
attention
_heads
//
tp_size
)
self
.
config
.
num_
key_value
_heads
//
tp_size
)
attention_weight_specs
=
[
# (weight_name, shard_size, offset)
(
"q_proj"
,
q_proj_shard_size
,
0
),
...
...
vllm/transformers_utils/configs/aquila.py
View file @
de894728
...
...
@@ -33,6 +33,7 @@ class AquilaConfig(PretrainedConfig):
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
num_attention_heads
=
32
,
num_key_value_heads
=
None
,
hidden_act
=
"silu"
,
max_position_embeddings
=
2048
,
initializer_range
=
0.006
,
...
...
@@ -49,6 +50,11 @@ class AquilaConfig(PretrainedConfig):
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
num_attention_heads
=
num_attention_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment