norm / vllm, commit 4f858475 (unverified)

Fix mqa is false case in gpt_bigcode (#806)

Authored Aug 22, 2023 by zhaoyang-star; committed by GitHub on Aug 21, 2023.
Parent: 65fc1c31
Showing 1 changed file with 11 additions and 8 deletions (+11, -8).
vllm/model_executor/models/gpt_bigcode.py (+11, -8)
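What the fix addresses: when multi_query is False, GPT-BigCode runs regular multi-head attention, and the fused c_attn projection appears to be column-parallel, so each tensor-parallel rank only holds hidden_size // world_size query channels of its output; splitting with the full hidden_size then mis-sizes the q/k/v slices. Below is a minimal standalone sketch of that mismatch using hypothetical sizes, not vLLM's layers:

import torch

# Hypothetical sizes for illustration only; the real values come from the
# model config and vLLM's tensor-parallel world size.
hidden_size = 8                     # full model hidden size
world_size = 2                      # assumed tensor-parallel world size
kv_dim = 2                          # assumed per-rank key/value width

# Per-rank output of the column-parallel c_attn: q, k, v concatenated.
qkv = torch.randn(1, hidden_size // world_size + 2 * kv_dim)

# Old split: uses the full hidden_size, which no longer matches qkv.
try:
    q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
except RuntimeError as err:
    print("old split fails:", err)

# Fixed split: the q slice is the per-rank width.
q, k, v = qkv.split([hidden_size // world_size, kv_dim, kv_dim], dim=-1)
print(q.shape, k.shape, v.shape)    # (1, 4) (1, 2) (1, 2)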
@@ -49,10 +49,11 @@ class GPTBigCodeAttention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = (
+        self.tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
-        assert total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = (total_num_heads //
+                          self.tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
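This constructor change stores the tensor-parallel world size on self so the forward pass can reuse it, and keeps the per-rank head count. A small standalone sketch of the same size bookkeeping, with assumed StarCoder-like numbers in place of vLLM's get_tensor_model_parallel_world_size():

# Assumed numbers: hidden 6144, 48 heads, 4 tensor-parallel ranks.
hidden_size = 6144
total_num_heads = 48
tensor_model_parallel_world_size = 4

assert total_num_heads % tensor_model_parallel_world_size == 0
num_heads = total_num_heads // tensor_model_parallel_world_size  # 12 per rank
head_dim = hidden_size // total_num_heads                        # 128
scale = head_dim ** -0.5                                         # ~0.0884
print(num_heads, head_dim, round(scale, 4))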
@@ -101,7 +102,10 @@ class GPTBigCodeAttention(nn.Module):
             k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
         else:
             qkv, _ = self.c_attn(hidden_states)
-            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
-                                dim=-1)
+            q, k, v = qkv.split([
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim, self.kv_dim
+            ],
+                                dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
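The forward-pass change: in the non-MQA branch the fused qkv tensor is already sharded per rank, so the q slice must be hidden_size // tensor_model_parallel_world_size. Below is a toy, self-contained stand-in (plain nn.Linear instead of vLLM's parallel layers, and kv_dim assumed to be the per-rank key/value width) that mirrors only the split sizes in the hunk above:

import torch
import torch.nn as nn

class ToyBigCodeQKV(nn.Module):
    """Toy stand-in for the projection/split logic above, not the vLLM module."""

    def __init__(self, hidden_size: int, kv_dim: int, world_size: int,
                 multi_query: bool):
        super().__init__()
        self.hidden_size = hidden_size
        self.kv_dim = kv_dim
        self.world_size = world_size
        self.multi_query = multi_query
        q_out = hidden_size // world_size
        if multi_query:
            # Sharded q projection plus a small key/value projection.
            self.q_attn = nn.Linear(hidden_size, q_out)
            self.kv_attn = nn.Linear(hidden_size, 2 * kv_dim)
        else:
            # Fused projection whose per-rank output is q_out + 2 * kv_dim.
            self.c_attn = nn.Linear(hidden_size, q_out + 2 * kv_dim)

    def forward(self, hidden_states: torch.Tensor):
        if self.multi_query:
            q = self.q_attn(hidden_states)
            kv = self.kv_attn(hidden_states)
            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
        else:
            qkv = self.c_attn(hidden_states)
            # The fix: slice q with the per-rank width, not hidden_size.
            q, k, v = qkv.split(
                [self.hidden_size // self.world_size, self.kv_dim,
                 self.kv_dim],
                dim=-1)
        return q, k, v

x = torch.randn(2, 6144)
mha = ToyBigCodeQKV(hidden_size=6144, kv_dim=1536, world_size=4,
                    multi_query=False)
q, k, v = mha(x)
print(q.shape, k.shape, v.shape)   # (2, 1536) (2, 1536) (2, 1536)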
@@ -255,8 +259,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
@@ -286,7 +288,8 @@ class GPTBigCodeForCausalLM(nn.Module):
         hidden_size = self.config.hidden_size
         head_size = hidden_size // total_num_heads
         total_kv_size = head_size * total_num_kv_heads
-        num_heads = total_num_heads // tensor_model_parallel_world_size
+        num_heads = (total_num_heads //
+                     self.tensor_model_parallel_world_size)
         head_start = tensor_model_parallel_rank * num_heads
         head_end = (tensor_model_parallel_rank + 1) * num_heads
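In load_weights, the per-rank head range (head_start, head_end) determines which slice of each full checkpoint tensor this rank keeps, and num_heads now comes from self.tensor_model_parallel_world_size. A standalone sketch with assumed numbers; the q-weight slicing line is illustrative rather than quoted from the file:

import torch

# Assumed numbers; vLLM derives rank and world size from its parallel state.
total_num_heads, world_size, rank = 48, 4, 1
hidden_size = 6144
head_size = hidden_size // total_num_heads             # 128
num_heads = total_num_heads // world_size              # 12 heads per rank
head_start = rank * num_heads                          # 12
head_end = (rank + 1) * num_heads                      # 24

# Slicing a full checkpoint q-projection weight down to this rank's heads.
full_q_weight = torch.randn(total_num_heads * head_size, hidden_size)
rank_q_weight = full_q_weight[head_start * head_size:head_end * head_size]
print(rank_q_weight.shape)                             # (1536, 6144)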
@@ -326,7 +329,7 @@ class GPTBigCodeForCausalLM(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
                 padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
+                    0] * self.tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])