Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
819b18e7
Unverified
Commit
819b18e7
authored
Nov 21, 2023
by
ljss
Committed by
GitHub
Nov 20, 2023
Browse files
Rewrite torch.repeat_interleave to remove cpu synchronization (#1599)
parent
19849db5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
10 deletions
+20
-10
vllm/model_executor/layers/attention.py
vllm/model_executor/layers/attention.py
+20
-10
No files found.
vllm/model_executor/layers/attention.py
View file @
819b18e7
...
@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
...
@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
"""
"""
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Project the key and value tensors to the desired number of heads.
# Project the key and value tensors to the desired number of heads.
key
=
torch
.
repeat_interleave
(
key
,
self
.
num_queries_per_kv
,
dim
=
1
)
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
value
=
torch
.
repeat_interleave
(
value
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
self
.
num_queries_per_kv
,
key
=
key
[:,
:,
dim
=
1
)
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out
=
xops
.
memory_efficient_attention_forward
(
out
=
xops
.
memory_efficient_attention_forward
(
...
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
...
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
scale
=
self
.
scale
,
scale
=
self
.
scale
,
)
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
output
.
copy_
(
out
.
squeeze
(
0
))
output
.
copy_
(
out
.
view_as
(
output
))
return
output
return
output
def
get_alibi_slopes
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
get_alibi_slopes
(
self
)
->
Optional
[
torch
.
Tensor
]:
...
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
...
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
"""
"""
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Project the key and value tensors to the desired number of heads.
# Project the key and value tensors to the desired number of heads.
key
=
torch
.
repeat_interleave
(
key
,
self
.
num_queries_per_kv
,
dim
=
1
)
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
value
=
torch
.
repeat_interleave
(
value
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
self
.
num_queries_per_kv
,
key
=
key
[:,
:,
dim
=
1
)
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
batch_size
=
input_metadata
.
num_prompts
batch_size
=
input_metadata
.
num_prompts
seq_len
=
input_metadata
.
max_prompt_len
seq_len
=
input_metadata
.
max_prompt_len
...
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
...
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
scale
=
self
.
scale
,
scale
=
self
.
scale
,
)
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
output
.
copy_
(
out
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
))
output
.
copy_
(
out
.
view_as
(
output
))
return
output
return
output
def
get_alibi_slopes
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
get_alibi_slopes
(
self
)
->
Optional
[
torch
.
Tensor
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment