Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14006840
Unverified
Commit
14006840
authored
Aug 18, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 18, 2025
Browse files
[V0 Deprecation] Remove V0 FlashInfer attention backend (#22776)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
66032887
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
9 additions
and
1133 deletions
+9
-1133
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+1
-8
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+1
-1
tests/core/block/e2e/test_correctness_sliding_window.py
tests/core/block/e2e/test_correctness_sliding_window.py
+2
-6
tests/distributed/test_pp_cudagraph.py
tests/distributed/test_pp_cudagraph.py
+0
-1
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+3
-0
tests/models/quantization/test_fp8.py
tests/models/quantization/test_fp8.py
+1
-4
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+0
-1098
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+1
-15
No files found.
tests/basic_correctness/test_basic_correctness.py
View file @
14006840
...
@@ -12,7 +12,6 @@ import pytest
...
@@ -12,7 +12,6 @@ import pytest
import
torch
import
torch
from
vllm
import
LLM
,
envs
from
vllm
import
LLM
,
envs
from
vllm.platforms
import
current_platform
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
LLMEngineV1
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
LLMEngineV1
from
..conftest
import
HfRunner
,
VllmRunner
from
..conftest
import
HfRunner
,
VllmRunner
...
@@ -78,11 +77,7 @@ def test_models(
...
@@ -78,11 +77,7 @@ def test_models(
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
"VLLM_USE_V1"
)
and
envs
.
VLLM_USE_V1
:
pytest
.
skip
(
"enable_prompt_embeds is not supported in v1."
)
pytest
.
skip
(
"enable_prompt_embeds is not supported in v1."
)
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
if
backend
==
"XFORMERS"
and
model
==
"google/gemma-2-2b-it"
:
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
backend
in
(
"XFORMERS"
,
"FLASHINFER"
)
and
model
==
"google/gemma-2-2b-it"
:
pytest
.
skip
(
pytest
.
skip
(
f
"
{
backend
}
does not support gemma2 with full context length."
)
f
"
{
backend
}
does not support gemma2 with full context length."
)
...
@@ -141,8 +136,6 @@ def test_models(
...
@@ -141,8 +136,6 @@ def test_models(
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"mp"
,
""
,
"L4"
,
{}),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"mp"
,
""
,
"L4"
,
{}),
(
"distilbert/distilgpt2"
,
"ray"
,
""
,
"A100"
,
{}),
(
"distilbert/distilgpt2"
,
"ray"
,
""
,
"A100"
,
{}),
(
"distilbert/distilgpt2"
,
"mp"
,
""
,
"A100"
,
{}),
(
"distilbert/distilgpt2"
,
"mp"
,
""
,
"A100"
,
{}),
(
"distilbert/distilgpt2"
,
"mp"
,
"FLASHINFER"
,
"A100"
,
{}),
(
"meta-llama/Meta-Llama-3-8B"
,
"ray"
,
"FLASHINFER"
,
"A100"
,
{}),
])
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
def
test_models_distributed
(
def
test_models_distributed
(
...
...
tests/compile/test_basic_correctness.py
View file @
14006840
...
@@ -34,7 +34,7 @@ class TestSetting:
...
@@ -34,7 +34,7 @@ class TestSetting:
model_args
=
[
"--max-model-len"
,
"2048"
],
model_args
=
[
"--max-model-len"
,
"2048"
],
pp_size
=
2
,
pp_size
=
2
,
tp_size
=
2
,
tp_size
=
2
,
attn_backend
=
"FLASH
INFER
"
,
attn_backend
=
"FLASH
_ATTN
"
,
method
=
"generate"
,
method
=
"generate"
,
fullgraph
=
True
,
fullgraph
=
True
,
),
),
...
...
tests/core/block/e2e/test_correctness_sliding_window.py
View file @
14006840
...
@@ -32,7 +32,7 @@ BLOCK_SIZE = 16
...
@@ -32,7 +32,7 @@ BLOCK_SIZE = 16
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"XFORMERS"
])
def
test_sliding_window_retrieval
(
baseline_llm_generator
,
test_llm_generator
,
def
test_sliding_window_retrieval
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
seed
,
backend
,
monkeypatch
):
batch_size
,
seed
,
backend
,
monkeypatch
):
"""
"""
...
@@ -43,8 +43,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
...
@@ -43,8 +43,6 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
Additionally, we compare the results of the v1 and v2 managers.
Additionally, we compare the results of the v1 and v2 managers.
"""
"""
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
backend
==
"XFORMERS"
and
current_platform
.
is_rocm
():
if
backend
==
"XFORMERS"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Xformers does not support ROCm/HIP."
)
pytest
.
skip
(
"Xformers does not support ROCm/HIP."
)
...
@@ -96,7 +94,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
...
@@ -96,7 +94,7 @@ def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"enable_chunked_prefill"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"enable_chunked_prefill"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"XFORMERS"
])
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
,
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
,
backend
,
monkeypatch
):
backend
,
monkeypatch
):
"""
"""
...
@@ -107,8 +105,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
...
@@ -107,8 +105,6 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
The results with and without chunked prefill are not the same due to
The results with and without chunked prefill are not the same due to
numerical instabilities.
numerical instabilities.
"""
"""
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
backend
==
"XFORMERS"
and
current_platform
.
is_rocm
():
if
backend
==
"XFORMERS"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Xformers does not support ROCm/HIP."
)
pytest
.
skip
(
"Xformers does not support ROCm/HIP."
)
override_backend_env_variable
(
monkeypatch
,
backend
)
override_backend_env_variable
(
monkeypatch
,
backend
)
...
...
tests/distributed/test_pp_cudagraph.py
View file @
14006840
...
@@ -17,7 +17,6 @@ if TYPE_CHECKING:
...
@@ -17,7 +17,6 @@ if TYPE_CHECKING:
])
])
@
pytest
.
mark
.
parametrize
(
"ATTN_BACKEND"
,
[
@
pytest
.
mark
.
parametrize
(
"ATTN_BACKEND"
,
[
"FLASH_ATTN"
,
"FLASH_ATTN"
,
"FLASHINFER"
,
])
])
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
def
test_pp_cudagraph
(
def
test_pp_cudagraph
(
...
...
tests/kernels/attention/test_attention_selector.py
View file @
14006840
...
@@ -81,6 +81,9 @@ def test_env(
...
@@ -81,6 +81,9 @@ def test_env(
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
if
name
==
"FLASHINFER"
and
not
use_v1
:
pytest
.
skip
(
"FlashInfer backend is only available on V1 engine"
)
if
device
==
"cpu"
:
if
device
==
"cpu"
:
if
not
use_v1
:
if
not
use_v1
:
pytest
.
skip
(
"CPU backend only supports V1"
)
pytest
.
skip
(
"CPU backend only supports V1"
)
...
...
tests/models/quantization/test_fp8.py
View file @
14006840
...
@@ -32,7 +32,7 @@ from ..utils import check_logprobs_close
...
@@ -32,7 +32,7 @@ from ..utils import check_logprobs_close
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"XFORMERS"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"XFORMERS"
])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
# reset distributed env properly. Use a value > 1 just when you test.
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
...
@@ -57,9 +57,6 @@ def test_models(
...
@@ -57,9 +57,6 @@ def test_models(
numerical sensitive kernels.
numerical sensitive kernels.
"""
"""
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
kv_cache_dtype
==
"fp8_e5m2"
and
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8_e5m2"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
pytest
.
skip
(
f
"
{
kv_cache_dtype
}
is currently not supported on ROCm/HIP."
)
f
"
{
kv_cache_dtype
}
is currently not supported on ROCm/HIP."
)
...
...
vllm/attention/backends/flashinfer.py
deleted
100644 → 0
View file @
66032887
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
vllm.multimodal
import
MultiModalPlaceholderMap
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
(
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
,
trtllm_batch_decode_with_kv_cache
)
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
# Avoid turning these types into variables during type checking
if
not
TYPE_CHECKING
:
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
trtllm_batch_decode_with_kv_cache
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
raise
ImportError
(
"FlashInfer is not installed. Please install it from "
"https://github.com/flashinfer-ai/flashinfer"
)
from
None
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
from
vllm.utils.flashinfer
import
use_trtllm_attention
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
class FlashInferBackend(AttentionBackend):
    """Backend descriptor for the FlashInfer attention implementation.

    Every member is a static hook: it either names the backend, hands out
    one of the companion classes defined in this module, or describes how
    the paged KV cache is shaped and manipulated.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch metadata class."""
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        """Return the class that builds the metadata."""
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        """Return the attention state class."""
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the 5-D KV-cache shape for the given sizes.

        The constant ``2`` in the second position presumably separates the
        key and value planes -- confirm against the kernels that consume it.
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the dimension permutation matching the active cache layout.

        "NHD" keeps the natural order; "HND" swaps dimensions 2 and 3.
        """
        layout = FlashInferState.get_kv_cache_layout()
        assert layout in ("NHD", "HND")
        if layout == "NHD":
            return (0, 1, 2, 3, 4)
        return (0, 1, 3, 2, 4)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate block swapping to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate block copying to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by this backend."""
        return [64, 128, 256]

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        """Translate an FP8 KV-cache dtype name into a torch dtype.

        Raises:
            ValueError: if the name is not one of the recognized FP8 names.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class PerLayerParameters:
    """
    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for the following hyperparameters.
    """

    # Left bound of the attention window; -1 is used when the layer has no
    # sliding window.
    window_left: int
    # Soft-capping value for attention logits, or None when the layer does
    # not set one.
    logits_soft_cap: Optional[float]
    # Softmax scale taken from the attention implementation.
    sm_scale: float
def get_per_layer_parameters(
        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
    """
    Collect, for every attention layer in the config, the hyperparameters
    that are needed during `plan`.

    Returns a mapping from layer name to its `PerLayerParameters`.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    collected: Dict[str, PerLayerParameters] = {}

    for layer_name, attn_layer in attn_layers.items():
        attn_impl = attn_layer.impl
        assert isinstance(attn_impl, FlashInferImpl)

        # Derive the hyperparameters from the layer's implementation object.
        sliding_window = attn_impl.sliding_window
        if sliding_window is not None:
            left_bound = sliding_window[0]
        else:
            # No sliding window configured for this layer.
            left_bound = -1
        soft_cap = attn_impl.logits_soft_cap
        softmax_scale = attn_impl.scale

        collected[layer_name] = PerLayerParameters(left_bound, soft_cap,
                                                   softmax_scale)

    return collected
def infer_global_hyperparameters(
        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Reduce per-layer hyperparameters to the single set shared by all layers.

    The FlashInfer backend currently requires every layer to agree on:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    The function asserts that the mapping is non-empty and that every entry
    equals the first one, then returns that first entry.
    """
    assert len(per_layer_params) > 0, "No attention layers found in the model."

    layer_values = list(per_layer_params.values())
    reference = layer_values[0]
    mismatch_message = (
        "FlashInfer backend currently only supports models in which all "
        "layers share the same values for the following hyperparameters: "
        "`window_left`, `logits_soft_cap`, `sm_scale`.")

    for candidate in layer_values:
        assert candidate == reference, mismatch_message

    return reference
class FlashInferState(AttentionState):
    """Per-runner FlashInfer state: workspace buffer and plan wrappers.

    The workspace buffer and the prefill/decode wrappers are created lazily
    on first use. While `graph_capture` is active, extra `_graph_*` buffers
    exist on the instance and are used to build CUDA-graph decode wrappers.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

        # Global hyperparameters shared by all attention layers
        self.global_hyperparameters: Optional[PerLayerParameters] = None

        self.vllm_config = self.runner.vllm_config
        self._kv_cache_layout = None

    def _get_workspace_buffer(self):
        """Return the uint8 workspace buffer, allocating it on first use."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    @staticmethod
    def get_kv_cache_layout():
        """Return the KV cache layout name.

        Precedence: the `_KV_CACHE_LAYOUT_OVERRIDE` value, then
        `envs.VLLM_KV_CACHE_LAYOUT`, then the default "NHD".
        """
        from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE
        if _KV_CACHE_LAYOUT_OVERRIDE is not None:
            logger.info_once("Using KV cache layout %s",
                             _KV_CACHE_LAYOUT_OVERRIDE)
            return _KV_CACHE_LAYOUT_OVERRIDE
        cache_layout = envs.VLLM_KV_CACHE_LAYOUT
        if cache_layout is None:
            logger.info_once("Using default KV cache layout NHD")
            return "NHD"
        logger.info_once("Using KV cache layout %s", cache_layout)
        return cache_layout

    def _get_prefill_wrapper(self):
        """Return the prefill wrapper, creating it on first use."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), self.get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the decode wrapper, creating it on first use.

        Tensor cores are enabled when forced through
        `envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES` or when the integer ratio
        of query heads to KV heads exceeds 4.
        """
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                self.get_kv_cache_layout(),
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that sets up and tears down graph-capture buffers.

        On entry the capturing flag is raised and the `_graph_*` buffers are
        allocated for up to `max_batch_size` requests. On exit the flag is
        lowered and the buffers are dropped, even when the body of the
        `with` block raises.
        """
        self._is_graph_capturing = True
        self._graph_decode_wrapper = None
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
        self._graph_indices_buffer = torch.empty(
            max_batch_size * self.runner.cache_config.num_gpu_blocks,
            dtype=torch.int32,
            device=self.runner.device)
        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                dtype=torch.int32,
                                                device=self.runner.device)
        self._graph_last_page_len_buffer = torch.empty(
            max_batch_size, dtype=torch.int32, device=self.runner.device)
        try:
            yield
        finally:
            # Always leave capture mode and release the capture buffers;
            # otherwise a failed capture would leave this state permanently
            # flagged as capturing and `begin_forward` would assert.
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables
            del self._graph_decode_workspace_buffer
            del self._graph_indices_buffer
            del self._graph_indptr_buffer
            del self._graph_last_page_len_buffer
            del self._graph_decode_wrapper

    def graph_clone(self, batch_size: int):
        """Return a new state sharing this state's capture-time wrappers.

        Only valid while graph capture is active. `batch_size` is accepted
        but not used here.
        """
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Build decode-only attention metadata for capturing one batch size.

        Creates a CUDA-graph decode wrapper over slices of the capture
        buffers, fills placeholder paging tensors on the host, builds the
        metadata through the runner's attention backend, and calls
        `begin_forward` on it before returning it.
        `is_encoder_decoder_model` is accepted but not used here.
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
            num_qo_heads // num_kv_heads > 4)
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer,
            self.get_kv_cache_layout(),
            use_tensor_cores)
        if self.runner.kv_cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.runner.kv_cache_dtype)
        else:
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # Placeholder paging layout: one page per request, each page full.
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        global_params = infer_global_hyperparameters(
            get_per_layer_parameters(self.vllm_config))

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            max_decode_seq_len=0,
            seq_lens_tensor=self._graph_seq_lens,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            q_data_type=self.runner.model_config.dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None,
            **dataclasses.asdict(global_params),
        )
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        """Return the metadata tensors used as graph input buffers."""
        return {
            "block_tables": attn_metadata.block_tables,
            "seq_lens_tensor": attn_metadata.seq_lens_tensor,
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        """Copy sequence lengths and block tables into the input buffers."""
        # FlashInfer-specific logic: copy additional tensors
        num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[
            0]
        input_buffers["seq_lens_tensor"][:num_total_blocks].copy_(
            attn_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"][:num_total_blocks].copy_(
            attn_metadata.block_tables, non_blocking=True)

    def begin_forward(self, model_input):
        """Attach prefill/decode wrappers to the metadata and start planning.

        For a decode-only batch with CUDA graphs enabled, the wrappers come
        from the captured graph runner's state for the matching batch size;
        otherwise they come from this state.
        """
        assert not self._is_graph_capturing
        state = self
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph
        is_decode = model_input.attn_metadata.num_prefills == 0
        # In case of multistep chunked-prefill, there might be prefill requests
        # scheduled while CUDA graph mode is enabled. We don't run graph in that
        # case.
        if use_cuda_graph and is_decode:
            if model_input.inputs_embeds is None:
                batch_size = model_input.input_tokens.shape[0]
                state = (
                    self.runner.graph_runners[model_input.virtual_engine][(
                        batch_size, False)].attn_state)
            else:
                batch_size = model_input.inputs_embeds.shape[0]
                state = (
                    self.runner.graph_runners[model_input.virtual_engine][(
                        batch_size, True)].attn_state)
        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()
@
dataclass
class
FlashInferMetadata
(
AttentionMetadata
):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
max_decode_seq_len
:
int
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
1
use_cuda_graph
:
bool
=
True
prefill_wrapper
:
Optional
[
BatchPrefillWithPagedKVCacheWrapper
]
=
None
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
# Metadata for the prefill stage
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# used for GPU operations
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
# The page indices of the paged kv cache
paged_kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len
:
Optional
[
torch
.
Tensor
]
=
None
# The number of query/output heads
num_qo_heads
:
Optional
[
int
]
=
None
# The number of key/value heads
num_kv_heads
:
Optional
[
int
]
=
None
# The dimension of the attention heads
head_dim
:
Optional
[
int
]
=
None
# Block size of vllm
page_size
:
Optional
[
int
]
=
None
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
# The data type of the query
q_data_type
:
torch
.
dtype
=
None
# FlashInfer 0.2 encourages passing host tensors
device
:
torch
.
device
=
torch
.
device
(
"cpu"
)
is_profile_run
:
bool
=
False
# The FlashInfer backend currently supports only models in which all layers
# share the same following hyperparameters:
# The left (inclusive) window size for the attention window, when
# set to `-1`, the window size will be set to the full length of
# the sequence. Defaults to `-1`.
window_left
:
int
=
-
1
# The attention logits soft capping value (used in Gemini, Grok and
# Gemma-2, etc.), if not provided, will be set to `0`. If greater
# than 0, the logits will be capped according to formula:
# $$\texttt{logits\_soft\_cap} \times
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
# where $x$ is the input logits.
logits_soft_cap
:
Optional
[
float
]
=
None
# The scale used in softmax, if not provided, will be set to
# `1.0 / sqrt(head_dim)`.
sm_scale
:
Optional
[
float
]
=
None
def
__post_init__
(
self
):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes
=
FlashInferBackend
.
get_supported_head_sizes
()
if
self
.
head_dim
is
not
None
and
self
.
head_dim
\
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
" received
{
self
.
head_dim
}
."
)
def
begin_forward
(
self
):
if
self
.
num_prefill_tokens
>
0
:
if
self
.
paged_kv_indices
is
None
:
return
assert
self
.
prefill_wrapper
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
assert
self
.
block_table_bound
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
self
.
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
]
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
assert
batch_size
>=
0
# We will use flash attention for profiling to
# determine the number of blocks. Therefore,
# we don't need to prepare the input for flashinfer for profile run.
if
not
self
.
is_profile_run
:
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
block_table_bound
=
self
.
block_table_bound
.
to
(
self
.
device
)
self
.
seq_lens_tensor
=
self
.
seq_lens_tensor
.
to
(
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
plan
(
self
.
query_start_loc
,
self
.
paged_kv_indptr
[:
self
.
num_prefills
+
1
],
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
[:
self
.
num_prefills
],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
data_type
)
if
self
.
num_decode_tokens
>
0
:
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
# handle model warmup path
if
self
.
block_table_bound
is
not
None
:
self
.
block_table_bound
=
self
.
block_table_bound
.
to
(
self
.
device
)
if
self
.
seq_lens_tensor
is
not
None
:
self
.
seq_lens_tensor
=
self
.
seq_lens_tensor
.
to
(
self
.
device
)
assert
self
.
decode_wrapper
is
not
None
self
.
decode_wrapper
.
plan
(
self
.
paged_kv_indptr
[
self
.
num_prefills
:],
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
[
self
.
num_prefills
:],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
sm_scale
=
self
.
sm_scale
,
# kv-cache data type.
kv_data_type
=
self
.
data_type
,
# query data type.
q_data_type
=
self
.
q_data_type
)
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
)
->
Dict
[
str
,
Any
]:
if
skip_fields
is
None
:
skip_fields
=
set
()
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields
.
add
(
'prefill_wrapper'
)
skip_fields
.
add
(
'decode_wrapper'
)
return
super
().
asdict_zerocopy
(
skip_fields
)
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"FlashInferMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
return
self
@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
    """This same object when the batch has decode tokens, else ``None``."""
    # No separate decode view is built; the metadata stands in for itself.
    return None if self.num_decode_tokens == 0 else self
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
    """Remember the input builder and the settings copied from it.

    Args:
        input_builder: Supplies ``runner``, ``sliding_window`` and
            ``block_size``; the runner in turn supplies ``vllm_config``.
    """
    runner = input_builder.runner

    self.input_builder = input_builder
    self.runner = runner

    self.sliding_window = input_builder.sliding_window
    self.block_size = input_builder.block_size

    # Hyperparameters common to all attention layers. Left unset here;
    # prepare() fills them in the first time it runs.
    self.global_hyperparameters: Optional[PerLayerParameters] = None

    self.vllm_config = runner.vllm_config
def prepare(self):
    """Reset the per-batch accumulators used while building metadata.

    The first call also infers the hyperparameters shared by all attention
    layers and caches them on the builder; later calls reuse the cache.
    """
    # Counters for the batch being assembled.
    self.num_prefills = 0
    self.num_prefill_tokens = 0
    self.num_decode_tokens = 0
    self.total_blocks = 0
    self.is_profile_run: bool = False

    # Per-sequence lists, appended to as sequence groups are added.
    self.slot_mapping: List[int] = []
    self.prefill_seq_lens: List[int] = []
    self.context_lens: List[int] = []
    self.block_tables: List[List[int]] = []
    self.curr_seq_lens: List[int] = []
    self.multimodal_placeholder_maps: Dict[
        str,
        MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

    # Paged-KV bookkeeping. See
    # https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
    # for the precise definition of these fields. Worked example:
    #   request 1 uses pages [0, 5, 8]
    #   request 2 uses pages [1, 6, 7]
    #   request 3 uses pages [3, 4]
    # paged_kv_indices concatenates the pages of all requests:
    #   [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr holds the offsets into paged_kv_indices:
    #   [0, 3, 6, 8]
    self.paged_kv_indices: List[int] = []
    # The leading 0 marks where the first request's pages start in
    # paged_kv_indices.
    self.paged_kv_indptr: List[int] = [0]
    # Number of entries in the last page of each request.
    self.paged_kv_last_page_len: List[int] = []

    if self.global_hyperparameters is not None:
        # Already inferred by an earlier call.
        return

    # Only models whose layers all share the same `window_left`,
    # `logits_soft_cap` and `sm_scale` are supported, so one set of
    # values is inferred and kept for the builder's lifetime.
    per_layer = get_per_layer_parameters(self.vllm_config)
    shared = infer_global_hyperparameters(per_layer)
    self.global_hyperparameters = shared
    self.window_left = shared.window_left
    self.logits_soft_cap = shared.logits_soft_cap
    self.sm_scale = shared.sm_scale
def _add_seq_group(
        self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
        chunked_prefill_enabled: bool):
    """Fold one sequence group into the builder's per-batch state.

    For every sequence in the group this records the context length,
    updates the prefill or decode counters, chooses a block table, extends
    the slot mapping and, unless this is a profile run, appends the
    sequence's pages to the paged-KV lists.

    Args:
        inter_data: Intermediate data of the sequence group.
        chunked_prefill_enabled: Whether chunked prefill is on; when it is,
            prompt sequences also take their block table from
            ``inter_data.block_tables``.
    """
    group_is_prompt = inter_data.is_prompt
    group_block_tables = inter_data.block_tables
    cached_block_nums = inter_data.computed_block_nums

    # Note the naming: `seq_len` comes from orig_seq_lens while
    # `curr_seq_len` comes from seq_lens.
    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in zip(
             inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
             inter_data.orig_seq_lens, inter_data.seq_lens,
             inter_data.query_lens, inter_data.context_lens,
             inter_data.curr_sliding_window_blocks):
        self.context_lens.append(context_len)

        if not group_is_prompt:
            # Decode: exactly one query token per sequence is expected.
            assert query_len == 1, (
                "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len))
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)
        else:
            # Prefill: merge multimodal placeholders, then count tokens.
            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)
            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

        # Choose the block table for this sequence.
        # TODO(sang): Combine chunked prefill and prefix caching by
        # only allowing multiple of block_size chunk size.
        # NOTE: This only works for oooooooxxx style attention.
        if inter_data.prefix_cache_hit:
            chosen_table = cached_block_nums
        elif (group_block_tables is not None
              and (chunked_prefill_enabled or not group_is_prompt)):
            chosen_table = group_block_tables[seq_id][
                -curr_sliding_window_block:]
        else:
            chosen_table = []
        self.block_tables.append(chosen_table)

        profiling = is_block_tables_empty(group_block_tables)

        # Extend the slot mapping for this sequence.
        first_slot_idx = compute_slot_mapping_start_idx(
            group_is_prompt, query_len, context_len, self.sliding_window)
        compute_slot_mapping(profiling, self.slot_mapping, seq_id, seq_len,
                             context_len, first_slot_idx, self.block_size,
                             inter_data.block_tables)

        # Profile runs use dummy inputs, so paged_kv_indices,
        # paged_kv_indptr and paged_kv_last_page_len are not needed;
        # flag the run and stop processing this group.
        if profiling:
            self.is_profile_run = profiling
            return

        self._update_paged_kv_tensors(group_block_tables[seq_id], seq_len)
def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
    """Append one request's pages to the paged-KV bookkeeping lists.

    Args:
        block_table: Page indices allocated to the request.
        seq_len: Sequence length used to decide how many of those pages
            are valid and how full the last one is.
    """
    # Every allocated block counts towards the total, valid or not.
    self.total_blocks += len(block_table)

    # Pages needed to hold seq_len entries, rounded up:
    #   seq_len = 16, block_size = 16 -> 1 valid block
    #   seq_len = 15, block_size = 16 -> 0 + 1 = 1 valid block
    full_pages, leftover = divmod(seq_len, self.block_size)
    valid_pages = full_pages + 1 if leftover != 0 else full_pages

    self.paged_kv_indices.extend(block_table[:valid_pages])
    self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + valid_pages)

    # A remainder of zero means the last page is completely full.
    self.paged_kv_last_page_len.append(
        self.block_size if leftover == 0 else leftover)
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int):
    """Build attention metadata with on-device tensors.

    Consumes every sequence group queued on the input builder, pads the
    per-batch lists when a captured CUDA graph is used, converts them to
    tensors and packs everything into a ``FlashInferMetadata``.

    Args:
        seq_lens: The maybe padded sequence lengths of the input sequences.
        query_lens: The query lengths of the input sequences.
        cuda_graph_pad_size: The padding size for cuda graph.
            -1 if cuda graph is not used.
        batch_size: The maybe padded batch size.

    Returns:
        The ``FlashInferMetadata`` describing this batch.
    """
    for inter_data in self.input_builder.inter_data_list:
        self._add_seq_group(inter_data,
                            self.input_builder.chunked_prefill_enabled)

    device = self.runner.device
    use_captured_graph = cuda_graph_pad_size != -1

    max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
    max_decode_seq_len = max(self.curr_seq_lens, default=0)
    num_decode_tokens = self.num_decode_tokens
    # Query lengths after the first num_prefills entries are treated as
    # the decode ones; fall back to 1 when there are none.
    decode_query_len = max(query_lens[self.num_prefills:], default=1)

    if use_captured_graph:
        self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
        # Pad with one empty block table per padding slot. The previous
        # `[] * cuda_graph_pad_size` evaluated to `[]` and added nothing.
        # Empty tables are skipped by the copy loop below, so the tensor
        # contents are unaffected.
        self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
        num_decode_tokens = batch_size - self.num_prefill_tokens

        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        input_block_tables = self.runner.graph_block_tables[:batch_size]
        max_blocks = input_block_tables.shape[1]
        for i, block_table in enumerate(self.block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    input_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    input_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        block_tables = torch.from_numpy(input_block_tables).to(
            device, non_blocking=True)

        # Padding requests own no pages: repeat the last offset and give
        # them a last-page length of zero.
        last_paged_kv_indptr = self.paged_kv_indptr[-1]
        self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                    cuda_graph_pad_size)
        self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
    else:
        block_tables = make_tensor_with_pad(
            self.block_tables,
            pad=0,
            dtype=torch.int,
            device=device,
        )

    assert device is not None
    seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                       self.runner.pin_memory)
    query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                         self.runner.pin_memory)
    slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                           device, self.runner.pin_memory)
    # Start offsets have one more entry than there are lengths; entry 0
    # stays zero and the cumulative sums below fill the rest.
    query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                  dtype=torch.int32,
                                  device=device)
    seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                dtype=torch.int32,
                                device=device)
    placeholder_index_maps = {
        modality: placeholder_map.index_map()
        for modality, placeholder_map in
        self.multimodal_placeholder_maps.items()
    }
    torch.cumsum(seq_lens_tensor,
                 dim=0,
                 dtype=seq_start_loc.dtype,
                 out=seq_start_loc[1:])
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])

    # NOTE(review): prepare() seeds paged_kv_indptr with [0], so this
    # condition looks always true after prepare(); the else branch is kept
    # as in the original - confirm whether it is still reachable.
    if len(self.paged_kv_indptr) > 0:
        # extend to the maximum number of blocks as returned by the
        # scheduler
        self.paged_kv_indices.extend(
            [0] * (self.total_blocks - len(self.paged_kv_indices)))
        # These tensors are created on the CPU here.
        paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                               device="cpu",
                                               dtype=torch.int)
        paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                              device="cpu",
                                              dtype=torch.int)
        paged_kv_last_page_len_tensor = torch.tensor(
            self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
        # One zero-initialised entry per request in paged_kv_indptr.
        block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                               1,
                                               device="cpu",
                                               dtype=torch.int)
    else:
        paged_kv_indices_tensor = None
        paged_kv_indptr_tensor = None
        paged_kv_last_page_len_tensor = None
        block_table_bound_tensor = None

    # fp8 cache dtypes are resolved through the backend's own helper; all
    # other dtypes go through the generic helper.
    if self.runner.kv_cache_dtype.startswith("fp8"):
        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
            self.runner.kv_cache_dtype)
    else:
        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)

    return FlashInferMetadata(
        decode_query_len=decode_query_len,
        num_prefills=self.num_prefills,
        slot_mapping=slot_mapping_tensor,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=False,
        num_prefill_tokens=self.num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        max_prefill_seq_len=max_prefill_seq_len,
        max_decode_seq_len=max_decode_seq_len,
        block_tables=block_tables,
        paged_kv_indptr=paged_kv_indptr_tensor,
        paged_kv_indices=paged_kv_indices_tensor,
        paged_kv_last_page_len=paged_kv_last_page_len_tensor,
        block_table_bound=block_table_bound_tensor,
        seq_lens_tensor=seq_lens_tensor,
        num_qo_heads=self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config),
        num_kv_heads=self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config),
        head_dim=self.runner.model_config.get_head_size(),
        page_size=self.block_size,
        seq_start_loc=seq_start_loc,
        query_start_loc=query_start_loc,
        device=device,
        data_type=kv_cache_dtype,
        q_data_type=self.runner.model_config.dtype,
        use_cuda_graph=use_captured_graph,
        is_profile_run=self.is_profile_run,
        window_left=self.window_left,
        logits_soft_cap=self.logits_soft_cap,
        sm_scale=self.sm_scale,
    )
class
FlashInferImpl
(
AttentionImpl
):
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None:
    """Store the attention configuration for this layer.

    Args:
        num_heads: Number of query heads.
        head_size: Size of each head.
        scale: Softmax scale; stored as a ``float``.
        num_kv_heads: Number of key/value heads.
        alibi_slopes: Optional slopes, stored as a float32 tensor.
        sliding_window: Optional window size. Stored as the pair
            ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when absent.
        kv_cache_dtype: KV-cache dtype name.
        logits_soft_cap: Optional logits soft cap.
        attn_type: Attention type; only ``AttentionType.DECODER`` is
            accepted.
        kv_sharing_target_layer_name: Must be ``None``.
        use_irope: When true, only a one-time warning is logged.

    Raises:
        NotImplementedError: If KV sharing is requested, or if
            ``attn_type`` is not ``AttentionType.DECODER``.
    """
    if kv_sharing_target_layer_name is not None:
        raise NotImplementedError("KV sharing is not supported in V0 "
                                  "FLASHINFER backend.")
    if use_irope:
        logger.warning_once(
            "Using irope in FlashInfer is not supported yet, it will fall"
            " back to global attention for long context.")

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads

    # Keep None as-is; otherwise convert the slopes to a float32 tensor.
    self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.float32))

    if sliding_window is None:
        self.sliding_window = (-1, -1)
    else:
        self.sliding_window = (sliding_window - 1, 0)

    self.kv_cache_dtype = kv_cache_dtype
    self.logits_soft_cap = logits_soft_cap
    self.num_queries_per_kv = num_heads // num_kv_heads

    # Checked last, as in the original, so the attributes above are set
    # before this can raise.
    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashInferImpl")
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashInferMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run attention for a batch that may mix prefill and decode tokens.

    New keys/values are written into ``kv_cache`` (when it is non-empty),
    prefill tokens and decode tokens are processed separately, and the two
    results are concatenated with the prefill part first.

    Args:
        layer: Attention layer providing the k/v scale attributes.
        query: Query tensor; unpacked as ``(num_tokens, hidden_size)``,
            so it must be 2-D.
        key: Key tensor, viewed as ``(-1, num_kv_heads, head_size)``.
        value: Value tensor, viewed as ``(-1, num_kv_heads, head_size)``.
        kv_cache: Paged KV cache; an empty tensor selects the
            flash-attention prefill path.
        attn_metadata: Metadata for this batch.
        output: Not read; the name is rebound to the computed result
            (see the TODO below).
        output_scale: Must be ``None``.

    Returns:
        The attention output viewed as ``(num_tokens, hidden_size)``.

    Raises:
        NotImplementedError: If ``output_scale`` is given.
    """
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for FlashInferImpl")

    # TODO: directly write to output tensor
    num_heads: int = self.num_heads
    head_size: int = self.head_size
    num_kv_heads: int = self.num_kv_heads
    kv_cache_dtype: str = self.kv_cache_dtype
    softmax_scale: float = self.scale
    window_size = self.sliding_window
    alibi_slopes = self.alibi_slopes
    logits_soft_cap = self.logits_soft_cap

    # Remember the flat shape so the result can be viewed back at the end.
    num_tokens, hidden_size = query.shape
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)

    if kv_cache.numel() > 0:
        # Use the same reshape and cache kernel as flash attention.
        ops.reshape_and_cache_flash(
            key,
            value,
            kv_cache[:, 0],
            kv_cache[:, 1],
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        # NOTE(review): this check is assumed to be nested under the
        # non-empty-cache branch; the original indentation is not
        # available in this view - confirm.
        if kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                kv_cache_dtype)
            kv_cache = kv_cache.view(torch_dtype)

    num_prefill_tokens = attn_metadata.num_prefill_tokens
    num_decode_tokens = attn_metadata.num_decode_tokens
    assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
                f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
    assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
                f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
    query = query.contiguous(
    )  # Flashinfer requires query to be contiguous
    # Query for decode. KV is not needed because it is already cached.
    # QKV for prefill.
    # The first num_prefill_tokens rows are the prefill part; the rest
    # are the decode part.
    decode_query = query[num_prefill_tokens:]
    query = query[:num_prefill_tokens]

    key = key[:num_prefill_tokens]
    value = value[:num_prefill_tokens]

    assert query.shape[0] == num_prefill_tokens
    assert decode_query.shape[0] == num_decode_tokens

    # First element of the (left, right) window pair set in __init__.
    window_left = window_size[0] if window_size is not None else -1

    prefill_output: Optional[torch.Tensor] = None
    if num_decode_tokens > 0:
        # Pre-allocated so the decode kernels can write into it via `out=`.
        decode_output = torch.empty(decode_query.shape,
                                    dtype=decode_query.dtype,
                                    device=decode_query.device)
    else:
        decode_output = None
    stride_order = FlashInferBackend.get_kv_cache_stride_order()
    if prefill_meta := attn_metadata.prefill_metadata:
        # We will use flash attention for prefill
        # when kv_cache is not provided.
        # This happens when vllm runs the profiling to
        # determine the number of blocks.
        if kv_cache.numel() == 0:
            prefill_output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_q=prefill_meta.max_prefill_seq_len,
                max_seqlen_k=prefill_meta.max_prefill_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
            )
        else:
            # Check that the wrapper's settings match this layer's before
            # running it.
            assert prefill_meta is not None
            assert prefill_meta.prefill_wrapper is not None

            assert prefill_meta.prefill_wrapper._causal
            assert prefill_meta.prefill_wrapper._window_left == window_left
            assert prefill_meta.prefill_wrapper._logits_soft_cap == (
                logits_soft_cap or 0.0)
            assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale

            prefill_output = prefill_meta.prefill_wrapper.run(
                query,
                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
            )
    if decode_meta := attn_metadata.decode_metadata:
        # Same consistency checks for the decode wrapper.
        assert decode_meta is not None
        assert decode_meta.decode_wrapper is not None
        assert decode_meta.decode_wrapper._window_left == window_left
        assert decode_meta.decode_wrapper._logits_soft_cap == (
            logits_soft_cap or 0.0)
        assert decode_meta.decode_wrapper._sm_scale == softmax_scale
        # TODO: @pavanimajety Remove this once the switch happens
        # inside flashinfer.
        if not use_trtllm_attention(
                num_decode_tokens, attn_metadata.max_decode_seq_len,
                kv_cache_dtype, attn_metadata.num_qo_heads,
                attn_metadata.num_kv_heads, attn_metadata.head_dim):
            decode_meta.decode_wrapper.run(
                decode_query,
                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=decode_output,
            )
        else:
            # The TRT-LLM path reuses the decode wrapper's float workspace
            # buffer and is only taken with the "HND" cache layout.
            workspace_buffer = (
                decode_meta.decode_wrapper._float_workspace_buffer)
            assert FlashInferState.get_kv_cache_layout() == "HND"
            trtllm_batch_decode_with_kv_cache(
                query=decode_query,
                kv_cache=kv_cache.permute(*stride_order),
                workspace_buffer=workspace_buffer,
                block_tables=attn_metadata.block_tables,
                seq_lens=decode_meta.seq_lens_tensor,
                max_seq_len=attn_metadata.max_decode_seq_len,
                bmm1_scale=layer._k_scale_float * softmax_scale,
                bmm2_scale=layer._v_scale_float,
                out=decode_output,
            )

    if prefill_output is None and decode_output is not None:
        # Decode only batch.
        output, num_tokens = decode_output, num_decode_tokens
    elif decode_output is None and prefill_output is not None:
        # Prefill only batch.
        output, num_tokens = prefill_output, num_prefill_tokens
    else:
        # Chunked prefill batch does not work with speculative decoding in
        # FlashInfer backend, so the query length for decode should be 1.
        # NOTE(review): a batch with neither output also lands here and
        # fails the first assertion.
        assert prefill_output is not None
        assert decode_output is not None
        assert decode_meta is not None
        assert decode_meta.decode_query_len == 1
        decode_output = decode_output.squeeze(1)
        # num_tokens keeps the value taken from query.shape above.
        output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)
vllm/platforms/cuda.py
View file @
14006840
...
@@ -350,17 +350,7 @@ class CudaPlatformBase(Platform):
...
@@ -350,17 +350,7 @@ class CudaPlatformBase(Platform):
return
FLEX_ATTENTION_V1
return
FLEX_ATTENTION_V1
# Backends for V0 engine
# Backends for V0 engine
if
selected_backend
==
_Backend
.
FLASHINFER
:
if
selected_backend
==
_Backend
.
XFORMERS
:
logger
.
info
(
"Using FlashInfer backend."
)
if
cls
.
has_device_capability
(
100
):
from
vllm.v1.attention.backends.utils
import
(
set_kv_cache_layout
)
logger
.
info_once
(
"Using HND KV cache layout on V1 engine by default for "
"Blackwell (SM 10.0) GPUs."
)
set_kv_cache_layout
(
"HND"
)
return
"vllm.attention.backends.flashinfer.FlashInferBackend"
elif
selected_backend
==
_Backend
.
XFORMERS
:
logger
.
info
(
"Using XFormers backend."
)
logger
.
info
(
"Using XFormers backend."
)
return
"vllm.attention.backends.xformers.XFormersBackend"
return
"vllm.attention.backends.xformers.XFormersBackend"
elif
selected_backend
==
_Backend
.
DUAL_CHUNK_FLASH_ATTN
:
elif
selected_backend
==
_Backend
.
DUAL_CHUNK_FLASH_ATTN
:
...
@@ -416,10 +406,6 @@ class CudaPlatformBase(Platform):
...
@@ -416,10 +406,6 @@ class CudaPlatformBase(Platform):
if
(
fp8_kv_cache
and
not
flash_attn_supports_fp8
()):
if
(
fp8_kv_cache
and
not
flash_attn_supports_fp8
()):
logger
.
info
(
logger
.
info
(
"Cannot use FlashAttention backend for FP8 KV cache."
)
"Cannot use FlashAttention backend for FP8 KV cache."
)
logger
.
warning
(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER"
)
target_backend
=
_Backend
.
XFORMERS
target_backend
=
_Backend
.
XFORMERS
except
ImportError
:
except
ImportError
:
logger
.
info
(
logger
.
info
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment