Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd2a6a82
Unverified
Commit
dd2a6a82
authored
Sep 02, 2024
by
Isotr0py
Committed by
GitHub
Sep 02, 2024
Browse files
[Bugfix] Fix internlm2 tensor parallel inference (#8055)
parent
4ca65a97
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
13 deletions
+34
-13
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+34
-13
No files found.
vllm/model_executor/models/internlm2.py
View file @
dd2a6a82
# -*- coding: utf-8 -*-
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -7,7 +8,10 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -70,20 +74,21 @@ class InternLM2Attention(nn.Module):
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
assert
self
.
total_num_heads
%
self
.
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
self
.
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
if
self
.
total_num_kv_heads
>=
self
.
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
assert
self
.
total_num_kv_heads
%
self
.
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
assert
self
.
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
self
.
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
...
...
@@ -122,11 +127,27 @@ class InternLM2Attention(nn.Module):
quant_config
=
quant_config
)
def split_qkv(
    self, qkv: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV projection output into separate q, k and v tensors.

    The fused tensor is viewed as
    ``(seq_len, total_num_kv_heads, key_value_groups + 2, head_dim)``: for
    every KV head there are ``key_value_groups`` query slices followed by one
    key slice and one value slice.

    With tensor parallelism (``tp_size > 1``) the per-rank outputs are first
    gathered and regrouped so the view above covers all KV heads; afterwards
    each of q/k/v is split along the last dimension and only this rank's
    partition is returned.

    Args:
        qkv: 2-D tensor whose first dimension is the sequence length.
            NOTE(review): presumably the last dimension is
            ``q_size + 2 * kv_size`` on each rank -- confirm against the
            projection layer that produces it.

    Returns:
        ``(q, k, v)`` with last dimensions ``q_size``, ``kv_size`` and
        ``kv_size`` respectively.
    """
    seq_len = qkv.shape[0]

    if self.tp_size > 1:
        # Chunk widths of one rank's [q | k | v] output, repeated per rank.
        qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
        # NOTE(review): assumes the gather concatenates along the last
        # dimension, which is what the split below relies on -- confirm.
        qkv = tensor_model_parallel_all_gather(qkv)
        chunks = torch.split(qkv, qkv_map, dim=-1)
        # Regroup from (q0, k0, v0, q1, k1, v1, ...) to
        # (q0, q1, ..., k0, k1, ..., v0, v1, ...).
        chunks = chunks[::3] + chunks[1::3] + chunks[2::3]
        qkv = torch.cat(chunks, dim=-1)

    # Use the configured head_dim and the total KV head count rather than a
    # hard-coded head size, so the view is valid for any model config.
    qkv = qkv.view(
        seq_len,
        self.total_num_kv_heads,
        self.key_value_groups + 2,
        self.head_dim,
    )
    q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
    q = q.reshape(seq_len, self.q_size * self.tp_size)
    k = k.reshape(seq_len, self.kv_size * self.tp_size)
    v = v.reshape(seq_len, self.kv_size * self.tp_size)

    if self.tp_size > 1:
        # Keep only the slice of heads that belongs to this rank.
        splitter = partial(split_tensor_along_last_dim,
                           num_partitions=self.tp_size)
        q = splitter(q)[self.tp_rank]
        k = splitter(k)[self.tp_rank]
        v = splitter(v)[self.tp_rank]
    return q, k, v
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment