Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
6428f1d0
Unverified
Commit
6428f1d0
authored
Dec 12, 2023
by
Megha Agarwal
Committed by
GitHub
Dec 12, 2023
Browse files
Support MPT with GQA (#1938)
Co-authored-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
7e1b21da
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
6 deletions
+28
-6
vllm/model_executor/layers/attention.py
vllm/model_executor/layers/attention.py
+8
-4
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+20
-2
No files found.
vllm/model_executor/layers/attention.py
View file @
6428f1d0
...
@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
...
@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
input_metadata
.
attn_bias
=
attn_bias
input_metadata
.
attn_bias
=
attn_bias
else
:
else
:
input_metadata
.
attn_bias
=
_make_alibi_bias
(
input_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
batch_size
,
seq_len
,
query
.
dtype
)
self
.
alibi_slopes
,
self
.
num_kv_heads
,
batch_size
,
seq_len
,
query
.
dtype
)
# TODO(woosuk): Too many view operations. Let's try to reduce them
# TODO(woosuk): Too many view operations. Let's try to reduce them
# in the future for code readability.
# in the future for code readability.
...
@@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
...
@@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
batch_size
:
int
,
batch_size
:
int
,
seq_len
:
int
,
seq_len
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
)
->
LowerTriangularMaskWithTensorBias
:
)
->
LowerTriangularMaskWithTensorBias
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
,
device
=
"cuda"
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# the bias below more accurately follows the original ALiBi
# paper.
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
bias
=
bias
.
to
(
alibi_slopes
.
device
)
# When using custom attention bias, xformers requires the bias to
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
# be sliced from a tensor whose length is a multiple of 8.
padded_len
=
(
seq_len
+
7
)
//
8
*
8
padded_len
=
(
seq_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
bias
=
torch
.
empty
(
batch_size
,
batch_size
,
alibi_slopes
.
shape
[
0
]
,
num_heads
,
seq_len
,
seq_len
,
padded_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
dtype
=
dtype
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
attn_bias
=
LowerTriangularMaskWithTensorBias
(
bias
)
attn_bias
=
LowerTriangularMaskWithTensorBias
(
bias
)
return
attn_bias
return
attn_bias
...
...
vllm/model_executor/models/mpt.py
View file @
6428f1d0
...
@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
...
@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
d_model
=
config
.
d_model
self
.
total_num_heads
=
config
.
n_heads
self
.
total_num_heads
=
config
.
n_heads
self
.
head_dim
=
self
.
d_model
//
self
.
total_num_heads
self
.
clip_qkv
=
config
.
attn_config
[
"clip_qkv"
]
self
.
clip_qkv
=
config
.
attn_config
[
"clip_qkv"
]
self
.
qk_ln
=
config
.
attn_config
[
"qk_ln"
]
self
.
qk_ln
=
config
.
attn_config
[
"qk_ln"
]
self
.
alibi_bias_max
=
config
.
attn_config
[
"alibi_bias_max"
]
self
.
alibi_bias_max
=
config
.
attn_config
[
"alibi_bias_max"
]
if
"kv_n_heads"
in
config
.
attn_config
:
self
.
total_num_kv_heads
=
config
.
attn_config
[
'kv_n_heads'
]
else
:
self
.
total_num_kv_heads
=
self
.
total_num_heads
assert
not
config
.
attn_config
[
"prefix_lm"
]
assert
not
config
.
attn_config
[
"prefix_lm"
]
assert
config
.
attn_config
[
"alibi"
]
assert
config
.
attn_config
[
"alibi"
]
...
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
...
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
self
.
d_model
,
self
.
d_model
,
self
.
d_model
//
self
.
total_num_heads
,
self
.
d_model
//
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
not
config
.
no_bias
,
bias
=
not
config
.
no_bias
,
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
...
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
...
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
assert
self
.
total_num_heads
%
tp_world_size
==
0
assert
self
.
total_num_heads
%
tp_world_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_world_size
self
.
num_heads
=
self
.
total_num_heads
//
tp_world_size
if
self
.
total_num_kv_heads
>=
tp_world_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_world_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_world_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_world_size
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
# Create the alibi slopes and slice them.
# Create the alibi slopes and slice them.
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
head_start
=
tp_rank
*
self
.
num_heads
head_start
=
tp_rank
*
self
.
num_heads
...
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
...
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scaling
,
scaling
,
alibi_slopes
=
alibi_slopes
)
alibi_slopes
=
alibi_slopes
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
...
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
]
,
dim
=-
1
)
if
self
.
qk_ln
:
if
self
.
qk_ln
:
q
=
self
.
q_ln
(
q
)
q
=
self
.
q_ln
(
q
)
k
=
self
.
k_ln
(
k
)
k
=
self
.
k_ln
(
k
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment