Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4dd830e
Unverified
Commit
f4dd830e
authored
Oct 05, 2024
by
youkaichao
Committed by
GitHub
Oct 05, 2024
Browse files
[core] use forward context for flash infer (#9097)
parent
5df18348
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
127 additions
and
67 deletions
+127
-67
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+127
-67
No files found.
vllm/attention/backends/flashinfer.py
View file @
f4dd830e
...
@@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
make_tensor_with_pad
)
...
@@ -761,10 +762,53 @@ class FlashInferImpl(AttentionImpl):
...
@@ -761,10 +762,53 @@ class FlashInferImpl(AttentionImpl):
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
"FlashInferImpl"
)
"FlashInferImpl"
)
return
torch
.
ops
.
vllm
.
unified_flash_infer
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
head_size
,
self
.
num_kv_heads
,
kv_cache
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
self
.
scale
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
@
torch
.
library
.
custom_op
(
"vllm::unified_flash_infer"
,
mutates_args
=
[
"kv_cache"
])
def
unified_flash_infer
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
current_metadata
=
get_forward_context
()
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashInferMetadata
)
attn_metadata
:
FlashInferMetadata
=
current_metadata
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
attn_metadata
.
num_prefill_tokens
>
0
:
if
attn_metadata
.
num_prefill_tokens
>
0
:
assert
attn_metadata
.
num_decode_tokens
==
0
,
(
assert
attn_metadata
.
num_decode_tokens
==
0
,
(
...
@@ -780,19 +824,18 @@ class FlashInferImpl(AttentionImpl):
...
@@ -780,19 +824,18 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
0
],
kv_cache
[:,
0
],
kv_cache
[:,
1
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
kv_cache_dtype
,
k_scale
,
k_scale
,
v_scale
,
v_scale
,
)
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
# to process the cache when the kv_cache_dtype is fp8
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
kv_cache_dtype
.
startswith
(
"fp8"
):
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
kv_cache_dtype
)
kv_cache_dtype
)
kv_cache
=
kv_cache
.
view
(
torch_dtype
)
kv_cache
=
kv_cache
.
view
(
torch_dtype
)
query
=
query
.
contiguous
(
query
=
query
.
contiguous
()
# Flashinfer requires query to be contiguous
)
# Flashinfer requires query to be contiguous
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# We will use flash attention for prefill
# when kv_cache is not provided.
# when kv_cache is not provided.
...
@@ -807,27 +850,44 @@ class FlashInferImpl(AttentionImpl):
...
@@ -807,27 +850,44 @@ class FlashInferImpl(AttentionImpl):
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
s
elf
.
scale
,
softmax_scale
=
s
oftmax_
scale
,
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_
window
,
window_size
=
window
_size
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
)
)
else
:
else
:
assert
prefill_meta
is
not
None
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
query
,
kv_cache
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
)
kv_cache
,
logits_soft_cap
=
self
.
logits_soft_cap
,
causal
=
True
)
else
:
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
query
,
query
,
kv_cache
,
kv_cache
,
sm_scale
=
s
elf
.
scale
,
sm_scale
=
s
oftmax_
scale
,
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
logits_soft_cap
,
k_scale
=
k_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
v_scale
=
v_scale
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
@
unified_flash_infer
.
register_fake
def
_
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
kv_cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment