Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b98cc28f
Unverified
Commit
b98cc28f
authored
Aug 28, 2024
by
Pavani Majety
Committed by
GitHub
Aug 28, 2024
Browse files
[Core][Kernels] Use FlashInfer backend for FP8 KV Cache when available. (#7798)
Co-authored-by:
Simon Mo
<
simon.mo@hey.com
>
parent
ef9baee3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
249 additions
and
12 deletions
+249
-12
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+222
-6
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+23
-6
vllm/attention/selector.py
vllm/attention/selector.py
+4
-0
No files found.
tests/kernels/test_flashinfer.py
View file @
b98cc28f
...
...
@@ -73,11 +73,14 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_decode_with_paged_kv
(
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
def
test_flashinfer_decode_with_paged_kv
(
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
kv_lens
)
...
...
@@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_value_cache
=
torch
.
randn
(
NUM_BLOCKS
,
2
,
block_size
,
...
...
@@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_tensor_cores
=
(
(
num_query_heads
//
num_kv_heads
)
not
in
(
1
,
2
,
4
,
8
)
)
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
...
...
@@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
soft_cap
=
soft_cap
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
132
),
(
5
,
18
)]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
[(
32
,
8
),
(
6
,
1
)])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
def
test_flashinfer_prefill_with_paged_fp8_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
kv_cache_dtype
=
torch
.
float8_e4m3fn
query
=
torch
.
randn
(
sum
(
query_lens
),
num_query_heads
,
head_size
,
dtype
=
dtype
)
NUM_BLOCKS_FP8
=
2048
key_value_cache
=
torch
.
randn
(
NUM_BLOCKS_FP8
,
2
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
key_cache
,
value_cache
=
torch
.
chunk
(
key_value_cache
,
2
,
dim
=
1
)
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
k_scale
=
key_cache
.
amax
().
item
()
/
448.0
v_scale
=
value_cache
.
amax
().
item
()
/
448.0
kv_cache_fp8
=
torch
.
cat
([
key_cache
/
k_scale
,
value_cache
/
v_scale
],
dim
=
1
).
to
(
kv_cache_dtype
)
assert
(
kv_cache_fp8
.
shape
==
key_value_cache
.
shape
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS_FP8
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
qo_indptr
=
[
0
]
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
seq_len
=
kv_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indptr
.
append
(
kv_indptr
[
-
1
]
+
num_blocks
)
kv_last_page_len
=
seq_len
%
block_size
if
kv_last_page_len
==
0
:
kv_last_page_len
=
block_size
kv_last_page_lens
.
append
(
kv_last_page_len
)
qo_indptr
.
append
(
qo_indptr
[
-
1
]
+
query_lens
[
i
])
qo_indptr
=
torch
.
tensor
(
qo_indptr
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
)
wrapper
.
begin_forward
(
qo_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
)
output
=
wrapper
.
forward
(
query
,
kv_cache_fp8
,
logits_soft_cap
=
soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
.
squeeze
(
1
),
value_cache
=
value_cache
.
squeeze
(
1
),
query_lens
=
query_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
)
del
query
del
block_tables
# verify prefill fp8
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
[(
32
,
8
),
(
64
,
8
),
(
6
,
1
)])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_decode_with_paged_fp8_kv
(
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
# test doesn't work for num_heads = (16,16)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
use_tensor_cores
=
(
num_query_heads
//
num_kv_heads
)
>
4
kv_cache_dtype
=
torch
.
float8_e4m3fn
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
NUM_BLOCKS_FP8
=
2048
key_value_cache
=
torch
.
randn
(
NUM_BLOCKS_FP8
,
2
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
key_cache
,
value_cache
=
torch
.
chunk
(
key_value_cache
,
2
,
dim
=
1
)
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
k_scale
=
key_cache
.
amax
().
item
()
/
448.0
v_scale
=
value_cache
.
amax
().
item
()
/
448.0
key_cache_fp8
=
(
key_cache
/
k_scale
).
to
(
kv_cache_dtype
)
value_cache_fp8
=
(
value_cache
/
v_scale
).
to
(
kv_cache_dtype
)
assert
(
key_cache_fp8
.
shape
[
1
]
==
1
and
value_cache_fp8
.
shape
[
1
]
==
1
)
kv_cache_fp8
=
torch
.
cat
([
key_cache_fp8
,
value_cache_fp8
],
dim
=
1
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS_FP8
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
seq_len
=
kv_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indptr
.
append
(
kv_indptr
[
-
1
]
+
num_blocks
)
kv_last_page_len
=
seq_len
%
block_size
if
kv_last_page_len
==
0
:
kv_last_page_len
=
block_size
kv_last_page_lens
.
append
(
kv_last_page_len
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_tensor_cores
=
use_tensor_cores
)
wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
data_type
=
dtype
)
output
=
wrapper
.
forward
(
query
,
kv_cache_fp8
,
logits_soft_cap
=
soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
key_cache
=
key_value_cache
[:,
0
,
:,
:,
:].
squeeze
(
1
)
value_cache
=
key_value_cache
[:,
1
,
:,
:,
:].
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
)
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
vllm/attention/backends/flashinfer.py
View file @
b98cc28f
...
...
@@ -83,6 +83,15 @@ class FlashInferBackend(AttentionBackend):
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
128
,
256
]
@
staticmethod
def
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
if
kv_cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
):
return
torch
.
float8_e4m3fn
elif
kv_cache_dtype
==
"fp8_e5m2"
:
return
torch
.
float8_e5m2
else
:
return
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
class
FlashInferState
(
AttentionState
):
...
...
@@ -177,9 +186,9 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
use_tensor_cores
)
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
runner
.
kv_cache_dtype
)
paged_kv_indptr_tensor_host
=
torch
.
arange
(
0
,
batch_size
+
1
,
dtype
=
torch
.
int32
)
...
...
@@ -340,7 +349,7 @@ class FlashInferMetadata(AttentionMetadata):
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
data_type
=
self
.
data_type
)
)
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
...
...
@@ -366,7 +375,8 @@ class FlashInferMetadata(AttentionMetadata):
def
decode_metadata
(
self
)
->
Optional
[
"FlashInferMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_prefills
>
0
:
assert
self
.
num_decode_tokens
==
0
assert
self
.
num_decode_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
return
None
return
self
...
...
@@ -578,6 +588,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
...
...
@@ -661,7 +672,6 @@ class FlashInferImpl(AttentionImpl):
if
attn_metadata
.
num_decode_tokens
>
0
:
assert
attn_metadata
.
num_prefill_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
kv_cache
is
not
None
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
...
...
@@ -674,6 +684,11 @@ class FlashInferImpl(AttentionImpl):
k_scale
,
v_scale
,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache in fp8
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
kv_cache_dtype
)
kv_cache
=
kv_cache
.
view
(
torch_dtype
)
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
...
...
@@ -711,5 +726,7 @@ class FlashInferImpl(AttentionImpl):
query
,
kv_cache
,
sm_scale
=
self
.
scale
,
logits_soft_cap
=
self
.
logits_soft_cap
)
logits_soft_cap
=
self
.
logits_soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/selector.py
View file @
b98cc28f
...
...
@@ -226,6 +226,10 @@ def which_attn_to_use(
elif
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
):
logger
.
info
(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
logger
.
warning
(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by set environment "
"VLLM_ATTENTION_BACKEND=FLASHINFER"
)
selected_backend
=
_Backend
.
XFORMERS
elif
block_size
%
16
!=
0
:
logger
.
info
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment