Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b682179
Unverified
Commit
3b682179
authored
Aug 20, 2024
by
Antoni Baum
Committed by
GitHub
Aug 20, 2024
Browse files
[Core] Add `AttentionState` abstraction (#7663)
parent
c6af027a
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
372 additions
and
247 deletions
+372
-247
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+6
-1
vllm/attention/__init__.py
vllm/attention/__init__.py
+2
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+50
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+6
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+164
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+5
-0
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+6
-1
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+5
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+6
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+5
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+73
-2
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+6
-1
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+1
-46
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+3
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+28
-188
No files found.
tests/worker/test_model_input.py
View file @
3b682179
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.worker.embedding_model_runner
import
(
from
vllm.worker.embedding_model_runner
import
(
...
@@ -29,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):
...
@@ -29,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
raise
AttentionMetadataBuilder
return
AttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this mock backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
...
vllm/attention/__init__.py
View file @
3b682179
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionState
,
AttentionType
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
...
@@ -12,5 +12,6 @@ __all__ = [
...
@@ -12,5 +12,6 @@ __all__ = [
"AttentionType"
,
"AttentionType"
,
"AttentionMetadataBuilder"
,
"AttentionMetadataBuilder"
,
"Attention"
,
"Attention"
,
"AttentionState"
,
"get_attn_backend"
,
"get_attn_backend"
,
]
]
vllm/attention/backends/abstract.py
View file @
3b682179
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
...
@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
...
@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import
torch
import
torch
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
)
class
AttentionType
(
Enum
):
class
AttentionType
(
Enum
):
...
@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
...
@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
raise
NotImplementedError
raise
NotImplementedError
@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
    """Return the AttentionState class that belongs to this backend.

    Abstract: the base implementation always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
@
classmethod
@
classmethod
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
...
@@ -126,6 +134,47 @@ class AttentionMetadata:
...
@@ -126,6 +134,47 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
class AttentionState(ABC, Generic[T]):
    """Container for attention backend-specific objects that are reused
    for as long as the model runner lives."""

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager to be entered while CUDA graphs are captured."""
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Return a clone of this state for storage in CUDA graph
        metadata."""

    @abstractmethod
    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
        """Return the attention metadata used to capture a CUDA graph for
        the given batch_size."""

    @abstractmethod
    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
        """Return the attention-specific input buffers needed for CUDA
        graph capture."""

    @abstractmethod
    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata: T) -> None:
        """Update the input buffers dict in place ahead of a CUDA graph
        replay."""

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Set the state up for a forward pass."""
class
AttentionMetadataBuilder
(
ABC
,
Generic
[
T
]):
class
AttentionMetadataBuilder
(
ABC
,
Generic
[
T
]):
"""Abstract class for attention metadata builders."""
"""Abstract class for attention metadata builders."""
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
3b682179
...
@@ -5,7 +5,8 @@ import torch
...
@@ -5,7 +5,8 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.blocksparse_attention.interface
import
(
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
...
@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"BlocksparseFlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"BlocksparseFlashAttentionMetadataBuilder"
]:
return
BlocksparseFlashAttentionMetadataBuilder
return
BlocksparseFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/flash_attn.py
View file @
3b682179
...
@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
...
@@ -142,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -142,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
return
FlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/flashinfer.py
View file @
3b682179
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
import
vllm.attention.backends.flash_attn
# noqa
import
vllm.attention.backends.flash_attn
# noqa
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
import
torch
import
torch
...
@@ -16,7 +21,7 @@ from vllm import _custom_ops as ops
...
@@ -16,7 +21,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionState
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
...
@@ -46,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
...
@@ -46,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"FlashInferMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"FlashInferMetadataBuilder"
]:
return
FlashInferMetadataBuilder
return
FlashInferMetadataBuilder
@staticmethod
def get_state_cls() -> Type["FlashInferState"]:
    """Return the FlashInfer-specific attention state class."""
    return FlashInferState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -75,6 +84,160 @@ class FlashInferBackend(AttentionBackend):
...
@@ -75,6 +84,160 @@ class FlashInferBackend(AttentionBackend):
return
[
64
,
128
,
256
]
return
[
64
,
128
,
256
]
class FlashInferState(AttentionState):
    """Attention state for the FlashInfer backend.

    Lazily creates and caches a workspace buffer plus the prefill and
    decode wrappers, and owns the temporary buffers that exist only while
    CUDA graphs are being captured.
    """

    # Attributes that only exist between entering and leaving
    # graph_capture(); they are removed again on exit.
    _GRAPH_CAPTURE_ATTRS = (
        "_graph_slot_mapping",
        "_graph_seq_lens",
        "_graph_block_tables",
        "_graph_decode_workspace_buffer",
        "_graph_indices_buffer",
        "_graph_indptr_buffer",
        "_graph_last_page_len_buffer",
        "_graph_decode_wrapper",
    )

    def __init__(self, runner):
        """Store the runner; buffers and wrappers are created on demand."""
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

    def _get_workspace_buffer(self):
        """Return the cached uint8 workspace tensor, allocating it on the
        runner's device on first use."""
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        """Return the cached prefill wrapper, creating it on first use."""
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        """Return the cached decode wrapper, creating it on first use.

        Tensor cores are requested when the integer ratio of query heads
        to KV heads is at least 4.
        """
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that holds the buffers used for graph capture.

        Allocates slot-mapping, sequence-length, block-table, indices,
        indptr and last-page-length buffers sized for max_batch_size and
        marks the state as capturing. On exit the capturing flag is
        cleared and the temporary attributes are removed, even when the
        body (or the setup itself) raises, so a failed capture cannot
        leave the state stuck in capturing mode.
        """
        self._is_graph_capturing = True
        try:
            self._graph_decode_wrapper = None
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)
            self._graph_decode_workspace_buffer = self._get_workspace_buffer()
            self._graph_indices_buffer = torch.empty(
                max_batch_size * self.runner.cache_config.num_gpu_blocks,
                dtype=torch.int32,
                device=self.runner.device)
            self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                    dtype=torch.int32,
                                                    device=self.runner.device)
            self._graph_last_page_len_buffer = torch.empty(
                max_batch_size, dtype=torch.int32, device=self.runner.device)
            yield
        finally:
            self._is_graph_capturing = False
            # Setup may have failed part-way, so only drop what exists.
            for attr_name in self._GRAPH_CAPTURE_ATTRS:
                if hasattr(self, attr_name):
                    delattr(self, attr_name)

    def graph_clone(self, batch_size: int):
        """Return a new state of the same class for a captured graph.

        The clone shares the capture workspace buffer, takes the most
        recently built graph decode wrapper, and reuses this state's
        prefill wrapper. Only valid inside graph_capture().
        """
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        """Build decode-only attention metadata for capturing batch_size.

        Creates a CUDA-graph decode wrapper over slices of the capture
        buffers, stores it on this state, and passes it to the backend's
        make_metadata together with CPU-side index tensors. The metadata's
        begin_forward() is called before it is returned. Only valid
        inside graph_capture().
        """
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            self._graph_decode_workspace_buffer, _indptr_buffer,
            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
            use_tensor_cores)
        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        # These tensors are created without a device argument.
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None)
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self, attn_metadata):
        """Return the input buffers for graph capture: only the metadata's
        slot_mapping."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
        """No-op: nothing is copied into the input buffers here."""
        return

    def begin_forward(self, model_input):
        """Attach prefill/decode wrappers to the metadata and start it.

        When the metadata uses a CUDA graph, the wrappers come from the
        attn_state of the graph runner selected by the input's
        virtual_engine and batch size (first dimension of input_tokens);
        otherwise they come from this state. Must not be called while
        graph capture is in progress.
        """
        assert not self._is_graph_capturing
        state = self
        if model_input.attn_metadata.use_cuda_graph:
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()
@
dataclass
@
dataclass
class
FlashInferMetadata
(
AttentionMetadata
):
class
FlashInferMetadata
(
AttentionMetadata
):
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
...
...
vllm/attention/backends/ipex_attn.py
View file @
3b682179
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
...
@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
...
@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
return
IpexAttnMetadata
return
IpexAttnMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/openvino.py
View file @
3b682179
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
,
Type
import
openvino
as
ov
import
openvino
as
ov
import
torch
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
)
from
vllm.attention.backends.utils
import
CommonAttentionState
class
OpenVINOAttentionBackend
(
AttentionBackend
):
class
OpenVINOAttentionBackend
(
AttentionBackend
):
...
@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
...
@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
raise
NotImplementedError
raise
NotImplementedError
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
make_openvino_metadata
(
*
args
,
**
kwargs
)
->
"OpenVINOAttentionMetadata"
:
def
make_openvino_metadata
(
*
args
,
**
kwargs
)
->
"OpenVINOAttentionMetadata"
:
return
OpenVINOAttentionMetadata
(
*
args
,
**
kwargs
)
return
OpenVINOAttentionMetadata
(
*
args
,
**
kwargs
)
...
...
vllm/attention/backends/pallas.py
View file @
3b682179
...
@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
...
@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
class
PallasAttentionBackend
(
AttentionBackend
):
class
PallasAttentionBackend
(
AttentionBackend
):
...
@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
...
@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"PallasMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"PallasMetadata"
]:
return
PallasMetadata
return
PallasMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
3b682179
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
return
ROCmFlashAttentionMetadataBuilder
return
ROCmFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/torch_sdpa.py
View file @
3b682179
...
@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
...
@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.utils
import
is_cpu
from
vllm.utils
import
is_cpu
...
@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
TorchSDPAMetadata
return
TorchSDPAMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/utils.py
View file @
3b682179
"""Attention backend utils"""
"""Attention backend utils"""
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Type
,
TypeVar
,
Union
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
# Error string(s) for encoder/decoder
# Error string(s) for encoder/decoder
# unsupported attention scenarios
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
=
(
"ROCm/HIP is not currently supported "
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
=
(
"ROCm/HIP is not currently supported "
...
@@ -269,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -269,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
)
)
class CommonAttentionState(AttentionState):
    """Attention state shared by backends without backend-specific state.

    It only owns the temporary slot-mapping, sequence-length and
    block-table tensors that exist while CUDA graphs are being captured.
    """

    # Attributes that only exist between entering and leaving
    # graph_capture(); they are removed again on exit.
    _GRAPH_CAPTURE_ATTRS = (
        "_graph_slot_mapping",
        "_graph_seq_lens",
        "_graph_block_tables",
    )

    def __init__(self, runner: "ModelRunnerBase"):
        """Store the runner and start outside of graph capture."""
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that holds the tensors used for graph capture.

        Allocates a slot mapping filled with PAD_SLOT_ID, a sequence
        length tensor of ones, and a device copy of the runner's
        graph_block_tables, and marks the state as capturing. On exit the
        flag is cleared and the tensors are dropped, even when the body
        (or the setup itself) raises, so a failed capture cannot leave
        the state stuck in capturing mode.
        """
        self._is_graph_capturing = True
        try:
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)
            yield
        finally:
            self._is_graph_capturing = False
            # Setup may have failed part-way, so only drop what exists.
            for attr_name in self._GRAPH_CAPTURE_ATTRS:
                if hasattr(self, attr_name):
                    delattr(self, attr_name)

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        """Return a fresh state of the same class bound to the same
        runner. Only valid inside graph_capture()."""
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        """Build decode-only attention metadata for capturing batch_size.

        The slot mapping, sequence lengths and block tables are the first
        batch_size entries of the capture tensors. Only valid inside
        graph_capture().
        """
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        return attn_metadata

    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
        """Return the slot mapping plus the decode metadata's sequence
        lengths and block tables, keyed by name."""
        return {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }

    def prepare_graph_input_buffers(self, input_buffers,
                                    attn_metadata) -> None:
        """Copy the decode metadata's sequence lengths and block tables
        into the existing input buffers (non-blocking, in place)."""
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)

    def begin_forward(self, model_input) -> None:
        """No-op: nothing needs preparing before a forward pass."""
        return
vllm/attention/backends/xformers.py
View file @
3b682179
...
@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
...
@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
...
@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
return
XFormersMetadataBuilder
return
XFormersMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention state class used by this backend."""
    return CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/spec_decode/draft_model_runner.py
View file @
3b682179
...
@@ -11,17 +11,6 @@ except ModuleNotFoundError:
...
@@ -11,17 +11,6 @@ except ModuleNotFoundError:
from
vllm.attention.backends.rocm_flash_attn
import
(
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
...
@@ -90,11 +79,6 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -90,11 +79,6 @@ class TP1DraftModelRunner(ModelRunner):
observability_config
=
observability_config
,
observability_config
=
observability_config
,
)
)
self
.
flashinfer_decode_workspace_buffer
=
None
self
.
flashinfer_decode_wrapper
=
None
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
def
_update_sampling_metadata
(
self
,
sampling_metadata
,
num_seqs
,
def
_update_sampling_metadata
(
self
,
sampling_metadata
,
num_seqs
,
num_queries
):
num_queries
):
...
@@ -270,36 +254,7 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -270,36 +254,7 @@ class TP1DraftModelRunner(ModelRunner):
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
model_input
.
prompt_adapter_mapping
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
self
.
attn_state
.
begin_forward
(
model_input
)
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
input_tokens
is
not
None
if
self
.
flashinfer_decode_workspace_buffer
is
None
:
self
.
flashinfer_decode_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_decode_wrapper
=
\
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_decode_workspace_buffer
,
"NHD"
)
self
.
flashinfer_prefill_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_prefill_wrapper
=
\
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_prefill_workspace_buffer
,
"NHD"
)
model_input
.
attn_metadata
.
prefill_wrapper
=
\
self
.
flashinfer_prefill_wrapper
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
graph_runners
[
model_input
.
virtual_engine
][
batch_size
].
flashinfer_decode_wrapper
else
:
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
flashinfer_decode_wrapper
model_input
.
attn_metadata
.
begin_forward
()
# Detect exec mode
# Detect exec mode
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
attn_metadata
is
not
None
...
...
vllm/worker/enc_dec_model_runner.py
View file @
3b682179
...
@@ -6,6 +6,7 @@ import torch.distributed
...
@@ -6,6 +6,7 @@ import torch.distributed
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
)
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.attention.selector
import
(
_Backend
,
get_env_variable_attn_backend
,
from
vllm.attention.selector
import
(
_Backend
,
get_env_variable_attn_backend
,
get_global_forced_attn_backend
,
get_global_forced_attn_backend
,
global_force_attn_backend
)
global_force_attn_backend
)
...
@@ -20,7 +21,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -20,7 +21,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_BACKEND
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
(
_PAD_SLOT_ID
,
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
,
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -395,7 +396,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -395,7 +396,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# initialized yet. In this case, we just use a dummy
# initialized yet. In this case, we just use a dummy
# slot mapping.
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
# In embeddings, the block tables are {seq_id: None}.
cross_slot_mapping
.
extend
([
_
PAD_SLOT_ID
]
*
seq_len
)
cross_slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
seq_len
)
else
:
else
:
for
i
in
range
(
0
,
seq_len
):
for
i
in
range
(
0
,
seq_len
):
block_number
=
seq_group_metadata
.
cross_block_table
[
block_number
=
seq_group_metadata
.
cross_block_table
[
...
...
vllm/worker/model_runner.py
View file @
3b682179
...
@@ -13,19 +13,10 @@ import torch
...
@@ -13,19 +13,10 @@ import torch
import
torch.distributed
import
torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
...
@@ -52,8 +43,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -52,8 +43,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
get_kv_cache_torch_dtype
,
is_hip
,
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
)
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
@@ -66,7 +56,6 @@ if TYPE_CHECKING:
...
@@ -66,7 +56,6 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
LORA_WARMUP_RANK
=
8
LORA_WARMUP_RANK
=
8
_BATCH_SIZE_ALIGNMENT
=
8
_BATCH_SIZE_ALIGNMENT
=
8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
...
@@ -858,6 +847,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -858,6 +847,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
)
if
num_attn_heads
else
None
)
if
num_attn_heads
else
None
if
self
.
attn_backend
:
self
.
attn_state
=
self
.
attn_backend
.
get_state_cls
()(
weakref
.
proxy
(
self
))
else
:
self
.
attn_state
=
CommonAttentionState
(
weakref
.
proxy
(
self
))
# Multi-modal data support
# Multi-modal data support
self
.
input_registry
=
input_registry
self
.
input_registry
=
input_registry
...
@@ -872,11 +866,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -872,11 +866,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
lora_manager
:
Optional
[
LRUCacheWorkerLoRAManager
]
=
None
self
.
lora_manager
:
Optional
[
LRUCacheWorkerLoRAManager
]
=
None
self
.
prompt_adapter_manager
:
LRUCacheWorkerPromptAdapterManager
=
None
self
.
prompt_adapter_manager
:
LRUCacheWorkerPromptAdapterManager
=
None
self
.
flashinfer_decode_workspace_buffer
=
None
self
.
flashinfer_decode_wrapper
=
None
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
set_cpu_offload_max_bytes
(
set_cpu_offload_max_bytes
(
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
...
@@ -1203,10 +1192,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1203,10 +1192,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size
=
max
(
_BATCH_SIZES_TO_CAPTURE
)
max_batch_size
=
max
(
_BATCH_SIZES_TO_CAPTURE
)
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_tokens
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
input_positions
=
torch
.
zeros
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
long
).
cuda
()
slot_mapping
.
fill_
(
_PAD_SLOT_ID
)
seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
).
cuda
()
block_tables
=
torch
.
from_numpy
(
self
.
graph_block_tables
).
cuda
()
intermediate_inputs
=
None
intermediate_inputs
=
None
if
not
get_pp_group
().
is_first_rank
:
if
not
get_pp_group
().
is_first_rank
:
intermediate_inputs
=
self
.
model
.
make_empty_intermediate_tensors
(
intermediate_inputs
=
self
.
model
.
make_empty_intermediate_tensors
(
...
@@ -1226,102 +1211,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1226,102 +1211,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
]
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
with
self
.
attn_state
.
graph_capture
(
# For flashinfer, different batch sizes will share the
max_batch_size
),
graph_capture
()
as
graph_capture_context
:
# same workspace buffer.
decode_workspace_buffer
=
\
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
indices_buffer
=
torch
.
empty
(
max_batch_size
*
self
.
cache_config
.
num_gpu_blocks
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
indptr_buffer
=
torch
.
empty
(
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
last_page_len_buffer
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
with
graph_capture
()
as
graph_capture_context
:
# NOTE: Capturing the largest batch size first may help reduce the
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
virtual_engine
in
range
(
for
virtual_engine
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
):
self
.
parallel_config
.
pipeline_parallel_size
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
attn_metadata
=
(
_indptr_buffer
=
indptr_buffer
[:
batch_size
+
1
]
self
.
attn_state
.
graph_capture_get_metadata_for_batch
(
_last_page_len_buffer
=
last_page_len_buffer
[:
batch_size
))
batch_size
]
num_qo_heads
=
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
))
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
if
num_qo_heads
//
num_kv_heads
>=
4
:
use_tensor_cores
=
True
else
:
use_tensor_cores
=
False
decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
decode_workspace_buffer
,
_indptr_buffer
,
indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
use_tensor_cores
)
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
paged_kv_indptr_tensor_host
=
torch
.
arange
(
0
,
batch_size
+
1
,
dtype
=
torch
.
int32
)
paged_kv_indices_tensor_host
=
torch
.
arange
(
0
,
batch_size
,
dtype
=
torch
.
int32
)
paged_kv_last_page_len_tensor_host
=
torch
.
full
(
(
batch_size
,
),
self
.
block_size
,
dtype
=
torch
.
int32
)
query_start_loc_host
=
torch
.
arange
(
0
,
batch_size
+
1
,
dtype
=
torch
.
int32
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
slot_mapping
=
slot_mapping
[:
batch_size
],
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
max_prefill_seq_len
=
0
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor_host
,
paged_kv_indices
=
paged_kv_indices_tensor_host
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor_host
,
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
num_kv_heads
,
head_dim
=
self
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
None
,
query_start_loc
=
query_start_loc_host
,
device
=
self
.
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
True
,
decode_wrapper
=
decode_wrapper
,
prefill_wrapper
=
None
)
attn_metadata
.
begin_forward
()
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
slot_mapping
[:
batch_size
],
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens
[:
batch_size
],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_seq_len_to_capture
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
)
if
self
.
lora_config
:
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_mapping
=
LoRAMapping
(
...
@@ -1339,17 +1238,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1339,17 +1238,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
set
(),
prompt_adapter_mapping
)
set
(),
prompt_adapter_mapping
)
graph_runner
=
CUDAGraphRunner
(
graph_runner
=
CUDAGraphRunner
(
self
.
model
,
self
.
attn_backend
.
get_name
())
self
.
model
,
self
.
attn_backend
.
get_name
(),
self
.
attn_state
.
graph_clone
(
batch_size
))
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
graph_runner
.
flashinfer_indptr_buffer
=
_indptr_buffer
graph_runner
.
flashinfer_indices_buffer
=
indices_buffer
graph_runner
.
flashinfer_last_page_len_buffer
=
\
_last_page_len_buffer
graph_runner
.
flashinfer_decode_workspace_buffer
=
\
decode_workspace_buffer
graph_runner
.
flashinfer_decode_wrapper
=
\
decode_wrapper
capture_inputs
=
{
capture_inputs
=
{
"input_ids"
:
"input_ids"
:
...
@@ -1476,36 +1366,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1476,36 +1366,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
model_input
.
prompt_adapter_mapping
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
self
.
attn_state
.
begin_forward
(
model_input
)
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
input_tokens
is
not
None
if
self
.
flashinfer_decode_workspace_buffer
is
None
:
self
.
flashinfer_decode_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_decode_wrapper
=
\
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_decode_workspace_buffer
,
"NHD"
)
self
.
flashinfer_prefill_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_prefill_wrapper
=
\
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_prefill_workspace_buffer
,
"NHD"
)
model_input
.
attn_metadata
.
prefill_wrapper
=
\
self
.
flashinfer_prefill_wrapper
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_input
.
attn_metadata
.
decode_wrapper
=
self
.
graph_runners
[
model_input
.
virtual_engine
][
batch_size
].
flashinfer_decode_wrapper
else
:
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
flashinfer_decode_wrapper
model_input
.
attn_metadata
.
begin_forward
()
# Currently cuda graph is only supported by the decode phase.
# Currently cuda graph is only supported by the decode phase.
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
attn_metadata
is
not
None
...
@@ -1613,22 +1474,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1613,22 +1474,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
class
CUDAGraphRunner
:
class
CUDAGraphRunner
:
def
__init__
(
self
,
model
:
nn
.
Module
,
backend_name
:
str
):
def
__init__
(
self
,
model
:
nn
.
Module
,
backend_name
:
str
,
attn_state
:
AttentionState
):
self
.
model
=
model
self
.
model
=
model
self
.
backend_name
=
backend_name
self
.
backend_name
=
backend_name
self
.
attn_state
=
attn_state
self
.
input_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
input_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_graph
:
Optional
[
torch
.
cuda
.
CUDAGraph
]
=
None
self
.
_graph
:
Optional
[
torch
.
cuda
.
CUDAGraph
]
=
None
self
.
flashinfer_decode_workspace_buffer
:
Optional
[
torch
.
Tensor
]
=
None
self
.
flashinfer_indptr_buffer
:
Optional
[
torch
.
Tensor
]
=
None
self
.
flashinfer_indices_buffer
:
Optional
[
torch
.
Tensor
]
=
None
self
.
flashinfer_last_page_len_buffer
:
Optional
[
torch
.
Tensor
]
=
None
self
.
flashinfer_decode_wrapper
:
Optional
[
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
]
=
None
@
property
@
property
def
graph
(
self
):
def
graph
(
self
):
assert
self
.
_graph
is
not
None
assert
self
.
_graph
is
not
None
...
@@ -1693,23 +1549,11 @@ class CUDAGraphRunner:
...
@@ -1693,23 +1549,11 @@ class CUDAGraphRunner:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Save the input and output buffers.
# Save the input and output buffers.
if
self
.
backend_name
==
"flashinfer"
:
self
.
input_buffers
=
{
self
.
input_buffers
=
{
"input_ids"
:
input_ids
,
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
**
self
.
attn_state
.
get_graph_input_buffers
(
attn_metadata
),
**
kwargs
,
}
else
:
self
.
input_buffers
=
{
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"seq_lens_tensor"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
**
kwargs
,
**
kwargs
,
}
}
if
intermediate_inputs
is
not
None
:
if
intermediate_inputs
is
not
None
:
...
@@ -1739,12 +1583,8 @@ class CUDAGraphRunner:
...
@@ -1739,12 +1583,8 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
non_blocking
=
True
)
if
self
.
backend_name
!=
"flashinfer"
:
self
.
attn_state
.
prepare_graph_input_buffers
(
self
.
input_buffers
,
self
.
input_buffers
[
"seq_lens_tensor"
].
copy_
(
attn_metadata
)
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
if
"seqlen_agnostic_capture_inputs"
in
self
.
input_buffers
:
self
.
model
.
copy_inputs_before_cuda_graphs
(
self
.
input_buffers
,
self
.
model
.
copy_inputs_before_cuda_graphs
(
self
.
input_buffers
,
**
kwargs
)
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment