Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2fa4623d
Unverified
Commit
2fa4623d
authored
Jul 17, 2024
by
Cody Yu
Committed by
GitHub
Jul 17, 2024
Browse files
[Core] Refactor _prepare_model_input_tensors - take 2 (#6164)
parent
a9a2e74d
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1050 additions
and
470 deletions
+1050
-470
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+5
-1
vllm/attention/__init__.py
vllm/attention/__init__.py
+3
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+43
-2
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+11
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+181
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+236
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+11
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+233
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+10
-0
vllm/attention/selector.py
vllm/attention/selector.py
+3
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+299
-459
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+15
-0
No files found.
tests/worker/test_model_input.py
View file @
2fa4623d
...
@@ -3,7 +3,7 @@ from typing import List, Tuple, Type
...
@@ -3,7 +3,7 @@ from typing import List, Tuple, Type
import
torch
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
...
@@ -26,6 +26,10 @@ class MockAttentionBackend(AttentionBackend):
...
@@ -26,6 +26,10 @@ class MockAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
AttentionMetadata
return
AttentionMetadata
@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
    """Return the metadata builder class used by this mock backend.

    The mock has no builder of its own, so it hands back the abstract
    AttentionMetadataBuilder class itself, which matches the annotation.
    """
    # AttentionMetadataBuilder is not an exception type, so `raise`-ing it
    # would fail with "TypeError: exceptions must derive from BaseException".
    return AttentionMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/__init__.py
View file @
2fa4623d
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataBuilder
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
...
@@ -7,6 +8,7 @@ __all__ = [
...
@@ -7,6 +8,7 @@ __all__ = [
"Attention"
,
"Attention"
,
"AttentionBackend"
,
"AttentionBackend"
,
"AttentionMetadata"
,
"AttentionMetadata"
,
"AttentionMetadataBuilder"
,
"Attention"
,
"Attention"
,
"get_attn_backend"
,
"get_attn_backend"
,
]
]
vllm/attention/backends/abstract.py
View file @
2fa4623d
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
TypeVar
)
Tuple
,
Type
,
TypeVar
)
import
torch
import
torch
if
TYPE_CHECKING
:
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
class
AttentionType
(
Enum
):
class
AttentionType
(
Enum
):
DECODER
=
auto
()
# Decoder attention between previous layer Q/K/V
DECODER
=
auto
()
# Decoder attention between previous layer Q/K/V
...
@@ -35,6 +39,16 @@ class AttentionBackend(ABC):
...
@@ -35,6 +39,16 @@ class AttentionBackend(ABC):
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
    """Return the AttentionMetadataBuilder subclass paired with this backend.

    Abstract: every concrete backend must override it. The class returned
    here is the one make_metadata_builder() instantiates.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
@classmethod
def make_metadata_builder(cls, *args,
                          **kwargs) -> "AttentionMetadataBuilder":
    """Create a metadata builder instance for this backend.

    Looks up the builder class through get_builder_cls() and constructs it,
    forwarding every positional and keyword argument unchanged.
    """
    builder_cls = cls.get_builder_cls()
    return builder_cls(*args, **kwargs)
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -110,6 +124,33 @@ class AttentionMetadata:
...
@@ -110,6 +124,33 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    A builder is used in two phases: add_seq_group() is called to record
    per-sequence information in plain Python objects, then build() turns the
    accumulated state into a metadata object of type ``T``.
    """

    @abstractmethod
    def __init__(self, input_builder) -> None:
        """Initialize the builder from the model-input builder.

        Args:
            input_builder: The object that owns this builder. The concrete
                builders in this change read ``sliding_window``,
                ``block_size`` and ``scheduler_config`` from it.
        """
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool,
                      chunked_prefill_enabled: bool) -> None:
        """Add a sequence group to the metadata and update
        corresponding fields (in Python objects).

        The list arguments are parallel: concrete builders zip them with the
        keys of ``seq_group_metadata.seq_data``, so entry ``i`` of each list
        describes the ``i``-th sequence of the group.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        Args:
            runner: Source of device and model configuration.
                NOTE(review): concrete builders annotate this argument as
                "GPUModelRunnerBase"; confirm which type is intended here.
            seq_lens: Sequence lengths for the whole batch.
            query_lens: Query lengths for the whole batch.
            cuda_graph_pad_size: Number of padding entries to add for CUDA
                graph execution; concrete builders treat -1 as "CUDA graph
                not used".
            batch_size: Batch size used when slicing the runner's CUDA graph
                block tables.
        """
        raise NotImplementedError
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
@
abstractmethod
@
abstractmethod
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
2fa4623d
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.blocksparse_attention.interface
import
(
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
@@ -93,6 +94,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
...
@@ -93,6 +94,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
BlocksparseFlashAttentionMetadata
return
BlocksparseFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
    """Return the metadata builder class for the blocksparse backend."""
    return BlocksparseFlashAttentionMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -244,6 +249,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -244,6 +249,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
    """Metadata builder for the blocksparse flash-attention backend.

    All behavior is inherited from CommonMetadataBuilder; this subclass only
    selects the metadata class.
    """

    # Metadata class associated with this builder.
    # NOTE(review): presumably instantiated by CommonMetadataBuilder.build();
    # confirm against that method.
    _metadata_cls = BlocksparseFlashAttentionMetadata
class
BlocksparseFlashAttentionImpl
(
AttentionImpl
):
class
BlocksparseFlashAttentionImpl
(
AttentionImpl
):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
...
...
vllm/attention/backends/flash_attn.py
View file @
2fa4623d
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -28,6 +39,10 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -28,6 +39,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
return
FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
    """Return the metadata builder class for the FlashAttention backend."""
    return FlashAttentionMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -184,6 +199,170 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -184,6 +199,170 @@ class FlashAttentionMetadata(AttentionMetadata):
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builder that assembles FlashAttentionMetadata for one batch.

    add_seq_group() records per-sequence information in Python lists and
    counters; build() converts that state into tensors on the runner's
    device and returns a FlashAttentionMetadata.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Create empty accumulators and copy settings from ``input_builder``.

        Args:
            input_builder: Provides ``sliding_window``, ``block_size`` and
                ``scheduler_config.use_v2_block_manager``.
        """
        # Per-batch state filled in by add_seq_group().
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        # Settings copied from the model-input builder.
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.

        The list arguments are parallel to
        ``seq_group_metadata.seq_data.keys()``: entry ``i`` of each list
        describes the ``i``-th sequence in the group.
        """
        is_prompt = seq_group_metadata.is_prompt
        block_tables = seq_group_metadata.block_tables
        # Depends only on the group's block tables, so it is evaluated once
        # rather than once per sequence.
        is_profile_run = is_block_tables_empty(block_tables)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
                 curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                # A window of 0 yields the slice [-0:], i.e. the whole table.
                block_table = block_tables[seq_id][
                    -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, block_tables)

    def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int) -> FlashAttentionMetadata:
        """Build attention metadata with on-device tensors.

        Args:
            runner: Provides ``device``, ``model_config.hf_config`` and
                ``graph_block_tables``.
            seq_lens: Sequence lengths for the whole batch.
            query_lens: Query lengths for the whole batch; must be non-empty
                with a positive maximum.
            cuda_graph_pad_size: Number of padding entries for CUDA graph
                execution, or -1 when CUDA graph is not used.
            batch_size: Number of rows taken from
                ``runner.graph_block_tables`` in CUDA graph mode.

        Raises:
            ValueError: If the HF config defines ``attn_logit_softcapping``,
                which this backend does not support.
        """
        device = runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padding slot. The previous
            # `[] * cuda_graph_pad_size` evaluated to `[]` and added nothing.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            # NOTE(review): assumes batch_size is the un-padded batch size;
            # confirm against the caller.
            num_decode_tokens = batch_size + cuda_graph_pad_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                # Empty tables (including the padding ones) leave their row
                # untouched.
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in self.block_tables)
            block_tables = make_tensor_with_pad(
                self.block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        # Start offsets: element 0 stays 0 and elements 1.. receive the
        # cumulative sums of the corresponding lengths.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                           dtype=torch.long,
                                           device=device)

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
class
FlashAttentionImpl
(
AttentionImpl
):
class
FlashAttentionImpl
(
AttentionImpl
):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
...
...
vllm/attention/backends/flashinfer.py
View file @
2fa4623d
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
...
@@ -14,7 +14,18 @@ import torch
...
@@ -14,7 +14,18 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -31,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
...
@@ -31,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashInferMetadata
return
FlashInferMetadata
@staticmethod
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
    """Return the metadata builder class for the FlashInfer backend."""
    return FlashInferMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -188,6 +203,225 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -188,6 +203,225 @@ class FlashInferMetadata(AttentionMetadata):
return
self
return
self
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Builder that assembles FlashInferMetadata for one batch.

    add_seq_group() records per-sequence information (including the paged
    KV-cache index lists) in Python lists and counters; build() converts
    that state into tensors and returns a FlashInferMetadata.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Create empty accumulators and copy settings from ``input_builder``.

        Args:
            input_builder: Provides ``sliding_window``, ``block_size`` and
                ``scheduler_config.use_v2_block_manager``.
        """
        # Per-batch state filled in by add_seq_group().
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        # Settings copied from the model-input builder.
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        4. paged KV-cache indices (skipped for profile runs).

        The list arguments are parallel to
        ``seq_group_metadata.seq_data.keys()``: entry ``i`` of each list
        describes the ``i``-th sequence in the group.
        """
        is_prompt = seq_group_metadata.is_prompt
        block_tables = seq_group_metadata.block_tables
        computed_block_nums = seq_group_metadata.computed_block_nums
        # Depends only on the group's block tables, so it is evaluated once
        # rather than once per sequence.
        is_profile_run = is_block_tables_empty(block_tables)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
                 curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                # A window of 0 yields the slice [-0:], i.e. the whole table.
                block_table = block_tables[seq_id][
                    -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            # `continue` (not `return`) so the remaining sequences of the
            # group still get their context length, block table and slot
            # mapping recorded above.
            if is_profile_run:
                continue

            # Get the number of valid blocks based on sequence length
            # (ceiling division).
            # If seq_len = 16, block_size = 16,
            # block_table_bound is 1 with 1 valid block.
            # If seq_len = 15, block_size = 16,
            # block_table_bound is 0 + 1 with 1 valid block.
            block_table_bound = ((seq_len + self.block_size - 1) //
                                 self.block_size)

            seq_block_table = block_tables[seq_id]
            self.paged_kv_indices.extend(seq_block_table[:block_table_bound])
            self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                        block_table_bound)

            # A sequence that exactly fills its last page reports a full
            # page rather than 0.
            last_page_len = seq_len % self.block_size
            if last_page_len == 0:
                last_page_len = self.block_size
            self.paged_kv_last_page_len.append(last_page_len)

    def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int) -> FlashInferMetadata:
        """Build attention metadata from the accumulated state.

        Args:
            runner: Provides ``device``, ``model_config``,
                ``parallel_config``, ``kv_cache_dtype`` and
                ``graph_block_tables``.
            seq_lens: Sequence lengths for the whole batch.
            query_lens: Query lengths for the whole batch; must be non-empty
                with a positive maximum.
            cuda_graph_pad_size: Number of padding entries for CUDA graph
                execution, or -1 when CUDA graph is not used.
            batch_size: Number of rows taken from
                ``runner.graph_block_tables`` in CUDA graph mode.
        """
        device = runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padding slot. The previous
            # `[] * cuda_graph_pad_size` evaluated to `[]` and added nothing.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            # NOTE(review): assumes batch_size is the un-padded batch size;
            # confirm against the caller.
            num_decode_tokens = batch_size + cuda_graph_pad_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                # Empty tables (including the padding ones) leave their row
                # untouched.
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)

            # Padding requests own no pages: repeat the last indptr value
            # and report a last-page length of 0 for each of them.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in self.block_tables)
            block_tables = make_tensor_with_pad(
                self.block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        # Start offsets: element 0 stays 0 and elements 1.. receive the
        # cumulative sums of the corresponding lengths.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                           dtype=torch.long,
                                           device=device)

        logits_soft_cap = getattr(runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)

        # NOTE(review): paged_kv_indptr is initialized to [0] and only ever
        # grows, so this condition always holds and the else branch looks
        # unreachable; kept as-is to preserve behavior.
        if len(self.paged_kv_indptr) > 0:
            # The paged-KV index tensors are created on the CPU.
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None

        kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
                                                  runner.model_config.dtype)

        return FlashInferMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            num_qo_heads=runner.model_config.get_num_attention_heads(
                runner.parallel_config),
            num_kv_heads=runner.model_config.get_num_kv_heads(
                runner.parallel_config),
            head_dim=runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            use_cuda_graph=use_captured_graph,
            logits_soft_cap=logits_soft_cap)
class
FlashInferImpl
(
AttentionImpl
):
class
FlashInferImpl
(
AttentionImpl
):
def
__init__
(
def
__init__
(
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
2fa4623d
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
ROCmFlashAttentionMetadata
return
ROCmFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
    """Return the metadata builder class for the ROCm flash-attention backend."""
    return ROCmFlashAttentionMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder for the ROCm flash-attention backend.

    All behavior is inherited from CommonMetadataBuilder; this subclass only
    selects the metadata class.
    """

    # Metadata class associated with this builder.
    # NOTE(review): presumably instantiated by CommonMetadataBuilder.build();
    # confirm against that method.
    _metadata_cls = ROCmFlashAttentionMetadata
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seq_lens
:
Optional
[
List
[
int
]],
seq_lens
:
Optional
[
List
[
int
]],
...
...
vllm/attention/backends/utils.py
View file @
2fa4623d
"""Attention backend utils"""
"""Attention backend utils"""
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Type
,
TypeVar
,
Union
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
# Error string(s) for encoder/decoder
# Error string(s) for encoder/decoder
# unsupported attention scenarios
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
=
(
"ROCm/HIP is not currently supported "
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
=
(
"ROCm/HIP is not currently supported "
"with encoder/decoder models."
)
"with encoder/decoder models."
)
PAD_SLOT_ID
=
-
1
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Tell whether ``block_tables`` carries no real block table.

    Returns True when ``block_tables`` is None, or when it is a dict whose
    values are all None (an empty dict counts). Anything else gives False.
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    return all(table is None for table in block_tables.values())
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int,
                                   use_v2_block_manager: bool):
    """
    Return the first token index whose slot is written to the KV cache.

    The result is 0 unless this is a prompt and ``sliding_window`` is set
    (not None). In that case the tokens that fall outside the window are
    skipped, giving ``max(0, query_len - sliding_window)``.

    With the V1 block manager, a prompt using a sliding window must have
    ``context_len == 0``; otherwise the assertion below fails.
    """
    # Decode steps and models without a sliding window start at 0.
    if not is_prompt or sliding_window is None:
        return 0

    assert use_v2_block_manager or context_len == 0, (
        "Prefix caching is currently not supported with "
        "sliding window attention in V1 block manager")
    # When prefill, we use it to not write slots to kv cache
    # to save memory.
    return max(0, query_len - sliding_window)
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """
    Append the KV-cache slots of one sequence to ``slot_mapping`` in place.

    For a profile run, ``seq_len`` copies of PAD_SLOT_ID are appended and
    nothing else happens. Otherwise tokens in ``[context_len, start_idx)``
    get PAD_SLOT_ID, and every token from ``max(start_idx, context_len)`` up
    to ``seq_len`` gets ``block_number * block_size + block_offset`` looked
    up in ``block_tables[seq_id]``.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    seq_block_table = block_tables[seq_id]
    num_masked = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * num_masked)

    first_token = max(start_idx, context_len)
    for pos in range(first_token, seq_len):
        # Split the token position into (index into block table, offset
        # inside that block).
        table_idx, offset = divmod(pos, block_size)
        slot_mapping.append(seq_block_table[table_idx] * block_size + offset)
# Type variable for the metadata class a builder produces; bound to
# AttentionMetadata (given as a forward-reference string).
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """
    Shared attention-metadata builder.

    Sequence groups are accumulated with ``add_seq_group`` and turned into
    an instance of ``_metadata_cls`` holding on-device tensors by ``build``.
    Subclasses only need to set ``_metadata_cls``.
    """

    # Concrete metadata class instantiated by build(); set by subclasses.
    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Start with empty accumulators and copy settings from the
        input builder (sliding window, block size, block manager version).
        """
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
                      chunked_prefill_enabled):
        """
        Accumulate the attention metadata of every sequence in a group.

        The per-sequence lists are zipped with the keys of
        ``seq_group_metadata.seq_data``, so they are expected to be in the
        same order. For each sequence this records its context length,
        updates the prefill/decode counters, appends its block table and
        extends the slot mapping.

        Raises:
            AssertionError: if a decode sequence has ``query_len != 1``.
        """
        is_prompt = seq_group_metadata.is_prompt
        block_tables = seq_group_metadata.block_tables
        computed_block_nums = seq_group_metadata.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
                 curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                # A value of 0 makes the slice [-0:], i.e. the whole table.
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size,
                                 seq_group_metadata.block_tables)

    def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int):
        """
        Build the attention metadata with tensors placed on the runner device.

        Args:
            runner: Supplies the device, the model config and the
                preallocated ``graph_block_tables`` buffer.
            seq_lens: Sequence length of every sequence in the batch.
            query_lens: Query length of every sequence in the batch.
            cuda_graph_pad_size: Number of padding entries to add for CUDA
                graph replay; -1 means CUDA graphs are not used.
            batch_size: Batch size used to slice ``graph_block_tables`` and
                to derive the decode token count in the CUDA-graph path.

        Returns:
            An instance of ``_metadata_cls``.

        Raises:
            ValueError: if the HF config defines ``attn_logit_softcapping``,
                or (from ``max``) if ``query_lens`` is empty.
            AssertionError: if the largest query length is not positive.
        """
        device = runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap "
                "(i.e., Gemma-2). Otherwise, the output might be wrong. "
                "Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padding entry. The previous
            # `[] * cuda_graph_pad_size` always evaluated to `[]`, so nothing
            # was appended. Empty tables are skipped by the copy loop below.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            # NOTE(review): adding the pad size here treats batch_size as
            # unpadded, while the slice below treats it as the row count of
            # the graph buffer -- confirm against the caller.
            num_decode_tokens = batch_size + cuda_graph_pad_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in self.block_tables)
            block_tables = make_tensor_with_pad(
                self.block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        # Start offsets: element 0 stays 0, the rest is the running sum of
        # the lengths, written in place through the `out=` view.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                           dtype=torch.long,
                                           device=device)

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
vllm/attention/backends/xformers.py
View file @
2fa4623d
...
@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
...
@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend):
...
@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
XFormersMetadata
return
XFormersMetadata
@staticmethod
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
    # Return the builder class itself (not an instance).
    return XFormersMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args(
...
@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args(
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    # Builder specialization for the xFormers backend: all it does is select
    # XFormersMetadata as the metadata class to instantiate.

    _metadata_cls = XFormersMetadata
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
...
...
vllm/attention/selector.py
View file @
2fa4623d
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_openvino
,
is_tpu
,
is_xpu
from
vllm.utils
import
is_cpu
,
is_hip
,
is_openvino
,
is_tpu
,
is_xpu
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -136,7 +137,7 @@ def which_attn_to_use(
...
@@ -136,7 +137,7 @@ def which_attn_to_use(
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
else
:
...
@@ -145,7 +146,7 @@ def which_attn_to_use(
...
@@ -145,7 +146,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
# FlashAttn in NVIDIA GPUs.
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
if
current_platform
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"Cannot use FlashAttention-2 backend for Volta and Turing "
...
...
vllm/worker/model_runner.py
View file @
2fa4623d
This diff is collapsed.
Click to expand it.
vllm/worker/model_runner_base.py
View file @
2fa4623d
...
@@ -113,6 +113,21 @@ class ModelRunnerInputBase(ABC):
...
@@ -113,6 +113,21 @@ class ModelRunnerInputBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.

    Abstract interface: subclasses must implement both methods; the base
    implementations only raise NotImplementedError.
    """

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the builder.

        NOTE(review): the exact accumulation contract is not defined here;
        see the concrete subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError
class
ModelRunnerBase
(
ABC
,
Generic
[
T
]):
class
ModelRunnerBase
(
ABC
,
Generic
[
T
]):
"""
"""
Model runner interface that abstracts a particular hardware and/or type of
Model runner interface that abstracts a particular hardware and/or type of
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment