Commit 7d5a155e (Unverified)
[Fix] Fix GPTBigcoder for distributed execution (#503)
Authored Jul 24, 2023 by Zhuohan Li · Committed by GitHub on Jul 24, 2023
Parent: 1dde34e0
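In brief: GPT-BigCode uses multi-query attention (MQA), in which all query heads share a single key/value head. Under tensor parallelism the query projection can be column-sharded across ranks, but the single KV head cannot be split, so this commit replaces the fused `c_attn` with a column-parallel `c_attn_q` plus a replicated `c_attn_kv` whenever `config.multi_query` is set, and shards the checkpoint weights accordingly. A minimal plain-PyTorch sketch of the resulting per-rank projection shapes; the sizes are illustrative, and `ColumnParallelLinear` is approximated by an ordinary `nn.Linear` that holds only this rank's slice:

```python
# Hedged sketch: per-rank projection shapes for GPT-BigCode multi-query
# attention under tensor parallelism. All sizes are illustrative, not taken
# from a real checkpoint, and vLLM's parallel layers are not used here.
import torch
import torch.nn as nn

hidden_size = 6144            # model width (illustrative)
total_num_heads = 48          # query heads (illustrative)
world_size = 4                # tensor-parallel ranks (illustrative)

head_dim = hidden_size // total_num_heads            # 128
num_heads_per_rank = total_num_heads // world_size   # 12
kv_dim = head_dim                                     # single shared KV head (MQA)

# Query projection: column-sharded, each rank owns a slice of the heads.
c_attn_q_shard = nn.Linear(hidden_size, num_heads_per_rank * head_dim, bias=True)
# Key/value projection: only one KV head exists, so every rank keeps a full copy.
c_attn_kv = nn.Linear(hidden_size, 2 * kv_dim, bias=True)

x = torch.randn(2, 16, hidden_size)                   # (batch, seq, hidden)
q = c_attn_q_shard(x)                                 # (2, 16, 1536) on this rank
k, v = c_attn_kv(x).split([kv_dim, kv_dim], dim=-1)   # each (2, 16, 128)
print(q.shape, k.shape, v.shape)
```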
Showing 1 changed file with 62 additions and 24 deletions.

vllm/model_executor/models/gpt_bigcode.py (+62 -24) · View file @ 7d5a155e
```diff
@@ -54,15 +54,30 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.num_kv_heads = 1 if config.multi_query else self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           self.hidden_size + 2 * self.kv_dim,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            self.num_kv_heads = 1
+            self.kv_dim = self.head_dim
+            self.c_attn_q = ColumnParallelLinear(self.hidden_size,
+                                                 self.hidden_size,
+                                                 bias=True,
+                                                 gather_output=False,
+                                                 perform_initialization=False)
+            self.c_attn_kv = nn.Linear(self.hidden_size,
+                                       2 * self.kv_dim,
+                                       bias=True)
+        else:
+            self.num_kv_heads = self.num_heads
+            self.kv_dim = self.num_kv_heads * self.head_dim
+            self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                               self.hidden_size + 2 * self.kv_dim,
+                                               bias=True,
+                                               gather_output=False,
+                                               perform_initialization=False)
         self.c_proj = RowParallelLinear(self.hidden_size,
                                         self.hidden_size,
                                         bias=True,
```
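A side note on why the fused projection is only kept for the non-MQA branch above: with a single shared KV head, an even column shard of the fused [Q | K | V] output would leave the K/V columns on one rank only. A small sketch with illustrative sizes:

```python
# Hedged sketch: where a naive contiguous column shard of a fused
# [Q | K | V] projection lands when there is only one shared KV head.
# All sizes are illustrative.
hidden_size = 6144
head_dim = 128                              # width of the single KV head
world_size = 4

fused_out = hidden_size + 2 * head_dim      # 6400 output columns
shard = fused_out // world_size             # 1600 columns per rank

for rank in range(world_size):
    lo, hi = rank * shard, (rank + 1) * shard
    has_kv = hi > hidden_size               # does this shard reach the K/V columns?
    print(f"rank {rank}: columns [{lo}, {hi}) -> contains K/V: {has_kv}")
# Only the last rank ends up holding the K/V columns, which is why the commit
# keeps a separate, replicated c_attn_kv instead of sharding the fused layer.
```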
```diff
@@ -80,9 +95,14 @@ class GPTBigCodeAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
-                            dim=-1)
+        if self.multi_query:
+            q, _ = self.c_attn_q(hidden_states)
+            kv = self.c_attn_kv(hidden_states)
+            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
+        else:
+            qkv, _ = self.c_attn(hidden_states)
+            q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                                dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
```
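For context, the multi-query semantics that the per-rank q, k, v above feed into: every query head attends against the same single K/V head, which the attention kernel (not shown in this diff) handles internally. The plain-PyTorch sketch below (illustrative sizes, no causal mask, not vLLM's paged-attention kernel) shows the broadcast:

```python
# Hedged sketch of multi-query attention semantics: one K/V head shared by
# every query head on this rank. Illustrative sizes only; no causal masking.
import torch

batch, seq, num_heads, head_dim = 2, 16, 12, 128   # per-rank query heads (illustrative)
q = torch.randn(batch, seq, num_heads, head_dim)
k = torch.randn(batch, seq, 1, head_dim)            # the single shared KV head
v = torch.randn(batch, seq, 1, head_dim)

# Replicate the shared K/V head across all query heads (a broadcasted view, no copy).
k = k.expand(-1, -1, num_heads, -1)
v = v.expand(-1, -1, num_heads, -1)

scale = head_dim ** -0.5
q, k, v = (t.transpose(1, 2) for t in (q, k, v))     # (batch, heads, seq, head_dim)
scores = (q @ k.transpose(-1, -2)) * scale           # (batch, heads, seq, seq)
out = scores.softmax(dim=-1) @ v                     # (batch, heads, seq, head_dim)
print(out.transpose(1, 2).shape)                     # torch.Size([2, 16, 12, 128])
```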
```diff
@@ -251,21 +271,9 @@ class GPTBigCodeForCausalLM(nn.Module):
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            param = state_dict[name]

             if not name.startswith("transformer."):
                 name = "transformer." + name

-            if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
```
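This hunk removes code rather than changing it: `param = state_dict[name]` and the `transformer.wte.weight` padding block move below the `c_attn` handling (they reappear in the next hunk), so that the new multi-query branch can `continue` before they run. The padding arithmetic itself is unchanged; a small standalone sketch with illustrative numbers (a GPT-2-sized vocab, not GPT-BigCode's):

```python
# Hedged sketch of the vocab-padding arithmetic in the relocated block.
# All numbers are illustrative, not taken from a real checkpoint.
import torch

vocab_size = 50257          # config.vocab_size (illustrative, GPT-2-sized)
world_size = 4
per_rank_rows = 12576       # param.shape[0] on one rank (illustrative)

padded_vocab_size = per_rank_rows * world_size        # 50304
num_extra_rows = padded_vocab_size - vocab_size       # 47 padding rows

loaded_weight = torch.randn(vocab_size, 1024)         # checkpoint wte weight (illustrative width)
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]).to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
print(loaded_weight.shape)                            # torch.Size([50304, 1024])
```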
```diff
@@ -291,9 +299,39 @@ class GPTBigCodeForCausalLM(nn.Module):
                     # Split the heads when using normal multi-head attention
                     wk = wk[head_size * head_start:head_size * head_end]
                     wv = wv[head_size * head_start:head_size * head_end]
-                # Else, keep the weights as is for multi-query attention
-
-                loaded_weight = torch.cat([wq, wk, wv], dim=0)
-
+                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
+                else:
+                    # For multi-query attention, we split the query
+                    # but replicate the key and value.
+                    loaded_weight_q = wq
+                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
+                    q_weight_name = name.replace("c_attn", "c_attn_q")
+                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
+                    load_tensor_parallel_weights(state_dict[q_weight_name],
+                                                 loaded_weight_q,
+                                                 q_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    load_tensor_parallel_weights(state_dict[kv_weight_name],
+                                                 loaded_weight_kv,
+                                                 kv_weight_name,
+                                                 self._column_parallel_weights,
+                                                 self._row_parallel_weights,
+                                                 tensor_model_parallel_rank)
+                    continue
+
+            param = state_dict[name]
+
+            if name == "transformer.wte.weight":
+                # Consider padding in the vocab size.
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
```
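A standalone sketch of the new multi-query weight-sharding path: split the checkpoint's fused `c_attn` weight into `wq`/`wk`/`wv` along the output dimension, slice `wq` down to this rank's heads, and keep `wk`/`wv` whole because `c_attn_kv` is replicated on every rank. Plain tensor slicing with illustrative sizes and a hypothetical rank:

```python
# Hedged sketch of the multi-query weight-sharding logic above, using plain
# tensor slicing and illustrative sizes (no real checkpoint involved).
import torch

hidden_size, total_num_heads, world_size = 6144, 48, 4    # illustrative
head_size = hidden_size // total_num_heads                 # 128
total_kv_size = head_size                                  # one shared KV head
rank = 1                                                   # hypothetical tensor-parallel rank

# Checkpoint layout of the fused c_attn weight: [Q | K | V] along dim 0.
loaded_weight = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)
wq, wk, wv = torch.split(loaded_weight,
                         [hidden_size, total_kv_size, total_kv_size], dim=0)

# Query rows are sharded by head; K/V rows are kept whole for every rank.
num_heads = total_num_heads // world_size                  # 12 heads on this rank
head_start, head_end = rank * num_heads, (rank + 1) * num_heads
loaded_weight_q = wq[head_size * head_start:head_size * head_end]
loaded_weight_kv = torch.cat([wk, wv], dim=0)

print(loaded_weight_q.shape)    # torch.Size([1536, 6144]) -> this rank's c_attn_q shard
print(loaded_weight_kv.shape)   # torch.Size([256, 6144])  -> full c_attn_kv copy
```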