Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4dd830e
Unverified
Commit
f4dd830e
authored
Oct 05, 2024
by
youkaichao
Committed by
GitHub
Oct 05, 2024
Browse files
[core] use forward context for flash infer (#9097)
parent
5df18348
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
127 additions
and
67 deletions
+127
-67
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+127
-67
No files found.
vllm/attention/backends/flashinfer.py
View file @
f4dd830e
...
...
@@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
...
...
@@ -761,73 +762,132 @@ class FlashInferImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
attn_metadata
.
num_prefill_tokens
>
0
:
assert
attn_metadata
.
num_decode_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
attn_metadata
.
num_decode_tokens
>
0
:
assert
attn_metadata
.
num_prefill_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
kv_cache
.
numel
()
>
0
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[:,
0
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
return
torch
.
ops
.
vllm
.
unified_flash_infer
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
head_size
,
self
.
num_kv_heads
,
kv_cache
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
self
.
scale
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
@torch.library.custom_op("vllm::unified_flash_infer",
                         mutates_args=["kv_cache"])
def unified_flash_infer(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Run FlashInfer attention as the ``vllm::unified_flash_infer`` custom op.

    The attention metadata is not an argument: it is read from
    ``get_forward_context()`` and must be a ``FlashInferMetadata``. Every
    per-layer setting is passed explicitly, so the function never touches
    ``self``.

    Args:
        query: 2-D tensor unpacked as ``(num_tokens, hidden_size)``; viewed
            as ``(-1, num_heads, head_size)`` for the kernels.
        key: Viewed as ``(-1, num_kv_heads, head_size)``.
        value: Viewed as ``(-1, num_kv_heads, head_size)``.
        num_heads: Head count used to reshape ``query``.
        head_size: Per-head size used to reshape query/key/value.
        num_kv_heads: Head count used to reshape ``key`` and ``value``.
        kv_cache: Cache tensor whose ``[:, 0]`` and ``[:, 1]`` slices are
            written with the new key and value. Declared in
            ``mutates_args``. When it has zero elements (the profiling run)
            no cache write happens and prefill uses flash attention.
        kv_cache_dtype: Cache dtype name. Values starting with ``"fp8"``
            make the cache get re-viewed as the dtype returned by
            ``FlashInferBackend.get_fp8_dtype_for_flashinfer``.
        k_scale: Forwarded to the cache-write kernel and the decode wrapper.
        v_scale: Forwarded to the cache-write kernel and the decode wrapper.
        softmax_scale: Softmax scaling forwarded to the attention kernels.
        window_size: Only forwarded to ``flash_attn_varlen_func``.
        alibi_slopes: Only forwarded to ``flash_attn_varlen_func``.
        logits_soft_cap: Forwarded to the FlashInfer prefill/decode wrappers.

    Returns:
        The attention output viewed as ``(num_tokens, hidden_size)``.

    Raises:
        AssertionError: If the forward context is missing or is not a
            ``FlashInferMetadata``, if prefill and decode tokens are mixed
            in one batch, or if the required wrapper is missing.
    """
    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, FlashInferMetadata)
    attn_metadata: FlashInferMetadata = current_metadata

    # Remember the flat layout so the result can be returned in it.
    num_tokens, hidden_size = query.shape
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)

    # A batch must be prefill-only or decode-only.
    if attn_metadata.num_prefill_tokens > 0:
        assert attn_metadata.num_decode_tokens == 0, (
            "Chunked prefill is not supported with flashinfer yet.")
    if attn_metadata.num_decode_tokens > 0:
        assert attn_metadata.num_prefill_tokens == 0, (
            "Chunked prefill is not supported with flashinfer yet.")

    if kv_cache.numel() > 0:
        # Use the same reshape and cache kernel as flash attention.
        ops.reshape_and_cache_flash(
            key,
            value,
            kv_cache[:, 0],
            kv_cache[:, 1],
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        # NOTE(review): the flattened source lost indentation; this check is
        # assumed to sit inside the non-empty-cache branch -- confirm.
        if kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                kv_cache_dtype)
            kv_cache = kv_cache.view(torch_dtype)

    query = query.contiguous()  # Flashinfer requires query to be contiguous
    if prefill_meta := attn_metadata.prefill_metadata:
        # We will use flash attention for prefill
        # when kv_cache is not provided.
        # This happens when vllm runs the profiling to
        # determine the number of blocks.
        if kv_cache.numel() == 0:
            output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_q=prefill_meta.max_prefill_seq_len,
                max_seqlen_k=prefill_meta.max_prefill_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
            )
        else:
            assert prefill_meta is not None
            assert prefill_meta.prefill_wrapper is not None
            output = prefill_meta.prefill_wrapper.forward(
                query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
    else:
        assert attn_metadata.decode_metadata is not None
        assert attn_metadata.decode_metadata.decode_wrapper is not None
        output = attn_metadata.decode_metadata.decode_wrapper.forward(
            query,
            kv_cache,
            sm_scale=softmax_scale,
            logits_soft_cap=logits_soft_cap,
            k_scale=k_scale,
            v_scale=v_scale)
    return output.view(num_tokens, hidden_size)
@unified_flash_infer.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_flash_infer``.

    Only ``query`` is used: the result is an uninitialized, contiguous
    tensor created with ``torch.empty_like(query)``. All other parameters
    exist to mirror the signature of the real op.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment