xdb4_94051 / vllm · Commits · a60b3530

Commit a60b3530 (Unverified)
Authored Oct 02, 2023 by Zhuohan Li, committed via GitHub on Oct 02, 2023

    support sharding llama2-70b on more than 8 GPUs (#1209)

    Co-authored-by: JiCheng <247153481@qq.com>

Parent: ebe4d1db

Changes: 1 changed file with 29 additions and 14 deletions (+29 -14)

vllm/model_executor/models/llama.py  (+29 -14)
@@ -103,8 +103,16 @@ class LlamaAttention(nn.Module):
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
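Read against the old code, the new `else` branch is what lets llama2-70b (which uses grouped-query attention with far fewer KV heads than query heads) run with a tensor parallel size larger than its KV head count: instead of failing the old divisibility assert, each KV head is replicated across several GPUs. A minimal standalone sketch of the same arithmetic, using example head counts that are assumptions rather than part of the diff:

# Hypothetical standalone sketch of the KV-head partition/replication logic above.
def kv_sharding(total_num_kv_heads: int, tp_size: int) -> tuple[int, int]:
    if total_num_kv_heads >= tp_size:
        # More KV heads than GPUs: partition them across GPUs.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than GPUs: replicate them across GPUs.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    num_kv_heads_replicas = max(1, tp_size // total_num_kv_heads)
    return num_kv_heads, num_kv_heads_replicas

# Assumed llama2-70b GQA setting: 8 KV heads.
print(kv_sharding(8, 8))    # (1, 1) -- one KV head per GPU, no replication
print(kv_sharding(8, 16))   # (1, 2) -- each KV head lives on 2 GPUs
print(kv_sharding(8, 32))   # (1, 4) -- each KV head lives on 4 GPUs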
@@ -114,7 +122,8 @@ class LlamaAttention(nn.Module):
         self.qkv_proj = ParallelLinear.column(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) *
+            (self.total_num_heads +
+             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
             self.head_dim,
             bias=False,
             gather_output=False,
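With the KV heads replicated, the fused QKV column-parallel layer is sized to include one copy of K and V per replica, so its total output width stays divisible by the tensor parallel size. A worked example with assumed llama2-70b shapes (hidden_size 8192, 64 query heads, 8 KV heads, head_dim 128); the numbers are illustrative, not taken from the diff:

# Illustrative arithmetic for the qkv_proj output width above (assumed llama2-70b shapes).
total_num_heads = 64          # query heads
total_num_kv_heads = 8        # KV heads (GQA)
head_dim = 128
tp_size = 16
num_kv_heads_replicas = max(1, tp_size // total_num_kv_heads)   # 2

qkv_out = (total_num_heads +
           2 * total_num_kv_heads * num_kv_heads_replicas) * head_dim
print(qkv_out)             # 12288 = (64 + 2 * 8 * 2) * 128
print(qkv_out // tp_size)  # 768 per GPU = (4 Q heads + 1 K head + 1 V head) * 128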
@@ -323,11 +332,15 @@ class LlamaForCausalLM(nn.Module):
             row_parallel_weights.append(f"{layer}.{suffix}")
 
         tp_size = get_tensor_model_parallel_world_size()
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        tp_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
+        num_kv_heads_replicas = max(1,
+                                    tp_size // self.config.num_key_value_heads)
+        num_kv_heads_per_gpu = max(1,
+                                   self.config.num_key_value_heads // tp_size)
         kv_proj_shard_size = (self.config.hidden_size //
                               self.config.num_attention_heads *
-                              self.config.num_key_value_heads // tp_size)
+                              num_kv_heads_per_gpu)
         attention_weight_specs = [
             # (weight_name, shard_size, offset)
             ("q_proj", q_proj_shard_size, 0),
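During weight loading, the per-GPU KV shard size is now clamped at one full KV head (`num_kv_heads_per_gpu` is at least 1) rather than being derived from `num_key_value_heads // tp_size`, which rounds toward a fraction of a head once tp_size exceeds the KV head count. With the same assumed llama2-70b shapes as above, again purely illustrative:

# Assumed llama2-70b config values, for illustration only.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8
tp_size = 16

q_proj_shard_size = hidden_size // tp_size                          # 512
num_kv_heads_replicas = max(1, tp_size // num_key_value_heads)      # 2
num_kv_heads_per_gpu = max(1, num_key_value_heads // tp_size)       # 1

# The old formula gave 8192 // 64 * 8 // 16 = 64, i.e. half a 128-wide KV head per GPU.
kv_proj_shard_size = hidden_size // num_attention_heads * num_kv_heads_per_gpu
print(q_proj_shard_size, kv_proj_shard_size)                        # 512 128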
@@ -363,9 +376,13 @@ class LlamaForCausalLM(nn.Module):
                         shard_size //= self.quant_config.pack_factor
                         offset //= self.quant_config.pack_factor
 
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                if weight_name in ["k_proj", "v_proj"]:
+                    shard_id = tp_rank // num_kv_heads_replicas
+                else:
+                    shard_id = tp_rank
+                loaded_weight = loaded_weight[
+                    shard_size * shard_id:shard_size * (shard_id + 1)]
                 param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
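For q_proj every rank still loads its own slice, but for k_proj and v_proj the shard index is divided by the replication factor, so ranks that share a replicated KV head read the same slice of the checkpoint. A small sketch of the rank-to-shard mapping, assuming tp_size 16 and 8 KV heads as in the earlier examples:

# Which KV shard each tensor parallel rank loads (illustrative values).
tp_size = 16
num_kv_heads_replicas = 2   # max(1, 16 // 8)

for tp_rank in range(tp_size):
    kv_shard_id = tp_rank // num_kv_heads_replicas
    print(f"rank {tp_rank:2d} -> k/v shard {kv_shard_id}")
# Ranks 0-1 load KV head 0, ranks 2-3 load KV head 1, ..., ranks 14-15 load KV head 7.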
@@ -384,9 +401,8 @@ class LlamaForCausalLM(nn.Module):
                     param = param.T
                 shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
+                                              (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *
                                          (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
@@ -402,10 +418,9 @@ class LlamaForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
-                                                  tensor_model_parallel_rank)
+                                                  tp_rank)
                 continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
-                                         row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         row_parallel_weights, tp_rank)