Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7de49aa8
"tests/vscode:/vscode.git/clone" did not exist on "08b2d845d6261309bfdb46933f872eebe4e2bb31"
Unverified
Commit
7de49aa8
authored
Sep 12, 2024
by
youkaichao
Committed by
GitHub
Sep 12, 2024
Browse files
[torch.compile] hide slicing under custom op for inductor (#8384)
parent
42ffba11
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
35 deletions
+74
-35
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+3
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+71
-34
No files found.
tests/compile/test_full_graph.py
View file @
7de49aa8
...
...
@@ -16,5 +16,7 @@ def test_full_graph(model):
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B"
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B"
,
enforce_eager
=
True
,
load_format
=
"dummy"
)
llm
.
generate
(
prompts
,
sampling_params
)
vllm/attention/backends/flash_attn.py
View file @
7de49aa8
...
...
@@ -122,6 +122,40 @@ def _(
return
torch
.
empty_like
(
decode_query
)
@torch.library.custom_op("vllm::reshape_and_cache_flash",
                         mutates_args=["kv_cache"])
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Store ``key`` and ``value`` into ``kv_cache`` through the cache kernel.

    Inductor cannot deal with inplace operations on views.
    See https://github.com/pytorch/pytorch/issues/131192
    and https://github.com/pytorch/pytorch/issues/130174

    As a workaround, the whole ``kv_cache`` tensor is taken as the (mutated)
    argument and the ``kv_cache[0]`` / ``kv_cache[1]`` views are created
    here, inside the custom op, so the view operation stays hidden from
    inductor.

    Args:
        key: Tensor forwarded as the first argument of the kernel.
        value: Tensor forwarded as the second argument of the kernel.
        kv_cache: Tensor whose index 0 and index 1 along the first dimension
            are handed to the kernel as two separate cache tensors
            (presumably the key cache and the value cache). Declared as
            mutated by this op.
        slot_mapping: Forwarded unchanged to the kernel.
        kv_cache_dtype: Forwarded unchanged to the kernel.
        k_scale: Forwarded unchanged to the kernel.
        v_scale: Forwarded unchanged to the kernel.
    """
    # Take the views inside the op body so the compiled graph only ever sees
    # the un-sliced ``kv_cache`` tensor.
    first_cache = kv_cache[0]
    second_cache = kv_cache[1]
    return torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        first_cache,
        second_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
@reshape_and_cache_flash.register_fake  # type: ignore
def _(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Fake implementation registered for ``reshape_and_cache_flash``.

    The op is declared to return ``None``, so there is no output tensor to
    describe; this implementation does nothing and accepts the same
    arguments as the real op.
    """
    return None
class
FlashAttentionBackend
(
AttentionBackend
):
@
staticmethod
...
...
@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
ops
.
reshape_and_cache_flash
(
torch
.
ops
.
vllm
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
...
...
@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl):
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
...
...
@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl):
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
is
None
...
...
@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
o
ut
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
prefill_outp
ut
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -701,14 +736,11 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
# noqa
prefill_output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -725,8 +757,7 @@ class FlashAttentionImpl(AttentionImpl):
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
...
...
@@ -738,5 +769,11 @@ class FlashAttentionImpl(AttentionImpl):
softcap
=
self
.
logits_soft_cap
,
).
squeeze
(
1
)
# Reshape the output tensor.
if
prefill_output
is
None
:
assert
decode_output
is
not
None
return
decode_output
.
view
(
num_decode_tokens
,
hidden_size
)
if
decode_output
is
None
:
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_tokens
,
hidden_size
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment