Commit 6428f1d0 (unverified), parent 7e1b21da
Authored Dec 12, 2023 by Megha Agarwal; committed by GitHub on Dec 12, 2023

Support MPT with GQA (#1938)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
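Some MPT checkpoints configure grouped-query attention (GQA) via attn_config["kv_n_heads"]: several query heads share a single key/value head, which shrinks the K/V projections and the KV cache. As a rough illustration only (not vLLM's paged-attention path, and the helper name below is made up), the sharing can be pictured as each KV head being reused by a fixed-size group of query heads:

# Illustrative sketch only: with GQA, each KV head is shared by a fixed-size
# group of query heads. One simple way to picture this is to repeat every KV
# head across its group before a standard multi-head attention op.
import torch

def expand_kv_for_gqa(kv: torch.Tensor, num_heads: int) -> torch.Tensor:
    """kv: [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]."""
    num_kv_heads = kv.shape[1]
    assert num_heads % num_kv_heads == 0
    group_size = num_heads // num_kv_heads
    return kv.repeat_interleave(group_size, dim=1)

# Hypothetical sizes: 32 query heads sharing 8 KV heads.
k = torch.randn(1, 8, 16, 64)
print(expand_kv_for_gqa(k, num_heads=32).shape)  # torch.Size([1, 32, 16, 64])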
Showing 2 changed files with 28 additions and 6 deletions (+28 / -6)

    vllm/model_executor/layers/attention.py    +8  -4
    vllm/model_executor/models/mpt.py          +20 -2
vllm/model_executor/layers/attention.py
@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
                     input_metadata.attn_bias = attn_bias
                 else:
                     input_metadata.attn_bias = _make_alibi_bias(
-                        self.alibi_slopes, batch_size, seq_len, query.dtype)
+                        self.alibi_slopes, self.num_kv_heads, batch_size,
+                        seq_len, query.dtype)
 
             # TODO(woosuk): Too many view operations. Let's try to reduce them
             # in the future for code readability.
@@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
     batch_size: int,
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
+    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but
     # the bias below more accurately follows the original ALiBi
     # paper.
     bias = bias[None, :] - bias[:, None]
-    bias = bias.to(alibi_slopes.device)
 
     # When using custom attention bias, xformers requires the bias to
     # be sliced from a tensor whose length is a multiple of 8.
     padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
     bias = torch.empty(
         batch_size,
-        alibi_slopes.shape[0],
+        num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
     attn_bias = LowerTriangularMaskWithTensorBias(bias)
     return attn_bias
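The hunk above threads num_kv_heads into _make_alibi_bias so that, for GQA models, the per-head ALiBi bias can be regrouped the same way the query heads are grouped under their shared KV heads. Below is a rough standalone sketch of that construction in plain PyTorch, without the xformers padding trick; the function name and example sizes are illustrative, not part of the vLLM API:

# Minimal sketch of the ALiBi bias build, assuming num_heads is a multiple of
# num_kv_heads. Names and shapes here are illustrative only.
import torch

def make_alibi_bias_sketch(alibi_slopes: torch.Tensor, num_kv_heads: int,
                           batch_size: int, seq_len: int,
                           dtype: torch.dtype) -> torch.Tensor:
    num_heads = alibi_slopes.shape[0]
    # Relative-position term: bias[i, j] = j - i, as in the ALiBi paper.
    pos = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
    bias = pos[None, :] - pos[:, None]                        # [seq, seq]
    bias = bias[None, None, :, :].expand(batch_size, num_heads, -1, -1)
    bias = bias * alibi_slopes[None, :, None, None]           # scale per head
    if num_heads != num_kv_heads:
        # Group query heads under their shared KV head, mirroring the
        # grouped query layout used for GQA.
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return bias

slopes = torch.tensor([1 / 2**i for i in range(1, 5)])        # 4 query heads
print(make_alibi_bias_sketch(slopes, num_kv_heads=2, batch_size=1,
                             seq_len=3, dtype=torch.float32).shape)
# torch.Size([1, 2, 2, 3, 3])

In the actual code the resulting tensor is wrapped in LowerTriangularMaskWithTensorBias, which layers the causal mask on top of this additive bias.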
vllm/model_executor/models/mpt.py
@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
         self.clip_qkv = config.attn_config["clip_qkv"]
         self.qk_ln = config.attn_config["qk_ln"]
         self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if "kv_n_heads" in config.attn_config:
+            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+        else:
+            self.total_num_kv_heads = self.total_num_heads
         assert not config.attn_config["prefix_lm"]
         assert config.attn_config["alibi"]
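The kv_n_heads lookup above keeps older MHA checkpoints working: when the key is absent, the KV head count simply falls back to the query head count. A tiny sketch of that behavior with made-up config dicts (not actual MPT configs):

# Hedged sketch of the fallback branch above.
gqa_attn_config = {"clip_qkv": 8, "qk_ln": False, "alibi": True,
                   "alibi_bias_max": 8, "prefix_lm": False, "kv_n_heads": 8}
mha_attn_config = {"clip_qkv": 8, "qk_ln": False, "alibi": True,
                   "alibi_bias_max": 8, "prefix_lm": False}

def total_kv_heads(attn_config: dict, total_num_heads: int) -> int:
    # Mirrors the new branch: fall back to MHA when kv_n_heads is absent.
    if "kv_n_heads" in attn_config:
        return attn_config["kv_n_heads"]
    return total_num_heads

print(total_kv_heads(gqa_attn_config, total_num_heads=64))  # 8  (GQA)
print(total_kv_heads(mha_attn_config, total_num_heads=64))  # 64 (plain MHA)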
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             self.d_model // self.total_num_heads,
             self.total_num_heads,
+            self.total_num_kv_heads,
             bias=not config.no_bias,
             linear_method=linear_method,
         )
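Passing total_num_kv_heads into the fused QKV projection narrows its output: instead of three equal d_model-wide slices, the packed output is (num_q_heads + 2 * num_kv_heads) * head_dim wide. A quick back-of-the-envelope check with hypothetical sizes:

# Illustrative arithmetic only, for a hypothetical d_model=4096, 32-head model
# with 8 KV heads.
d_model, n_heads, n_kv_heads = 4096, 32, 8
head_dim = d_model // n_heads                       # 128

mha_width = (n_heads + 2 * n_heads) * head_dim      # 12288 == 3 * d_model
gqa_width = (n_heads + 2 * n_kv_heads) * head_dim   # 6144, half the MHA width
print(mha_width, gqa_width)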
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
         assert self.total_num_heads % tp_world_size == 0
         self.num_heads = self.total_num_heads // tp_world_size
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
 
         # Create the alibi slopes and slice them.
         tp_rank = get_tensor_model_parallel_rank()
         head_start = tp_rank * self.num_heads
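The new tensor-parallel logic covers both regimes: when there are at least as many KV heads as TP ranks they are partitioned across ranks, otherwise they are replicated so every rank still holds at least one. A small standalone sketch of the resulting per-rank count (not the vLLM class, same divisibility constraints as the asserts):

def kv_heads_per_rank(total_num_kv_heads: int, tp_world_size: int) -> int:
    if total_num_kv_heads >= tp_world_size:
        # Partition: each rank owns a disjoint slice of the KV heads.
        assert total_num_kv_heads % tp_world_size == 0
    else:
        # Replicate: several ranks hold a copy of the same KV head.
        assert tp_world_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_world_size)

print(kv_heads_per_rank(8, 4))   # 2 -> partitioned, 2 KV heads per GPU
print(kv_heads_per_rank(8, 16))  # 1 -> replicated, each KV head on 2 GPUs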
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scaling,
-                                   alibi_slopes=alibi_slopes)
+                                   alibi_slopes=alibi_slopes,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.qk_ln:
             q = self.q_ln(q)
             k = self.k_ln(k)
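The final hunk replaces the equal-thirds chunk with an explicit split: with GQA the key/value slices of the packed QKV output are narrower than the query slice, so chunk(chunks=3) would cut at the wrong offsets. A quick sketch with made-up sizes:

# Why chunk -> split matters under GQA. Sizes below are illustrative only.
import torch

num_heads, num_kv_heads, head_dim = 32, 8, 128
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim  # 4096, 1024

qkv = torch.randn(2, q_size + 2 * kv_size)            # [batch, 6144]
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)                       # widths 4096 / 1024 / 1024

# qkv.chunk(chunks=3, dim=-1) would instead yield three 2048-wide pieces,
# mixing query columns into the key/value tensors.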