Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2fa4623d
Unverified
Commit
2fa4623d
authored
Jul 17, 2024
by
Cody Yu
Committed by
GitHub
Jul 17, 2024
Browse files
[Core] Refactor _prepare_model_input_tensors - take 2 (#6164)
parent
a9a2e74d
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1050 additions
and
470 deletions
+1050
-470
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+5
-1
vllm/attention/__init__.py
vllm/attention/__init__.py
+3
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+43
-2
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+11
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+181
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+236
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+11
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+233
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+10
-0
vllm/attention/selector.py
vllm/attention/selector.py
+3
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+299
-459
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+15
-0
No files found.
tests/worker/test_model_input.py
View file @
2fa4623d
...
...
@@ -3,7 +3,7 @@ from typing import List, Tuple, Type
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
...
...
@@ -26,6 +26,10 @@ class MockAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
AttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
raise
AttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
vllm/attention/__init__.py
View file @
2fa4623d
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataBuilder
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
...
...
@@ -7,6 +8,7 @@ __all__ = [
"Attention"
,
"AttentionBackend"
,
"AttentionMetadata"
,
"AttentionMetadataBuilder"
,
"Attention"
,
"get_attn_backend"
,
]
vllm/attention/backends/abstract.py
View file @
2fa4623d
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
enum
import
Enum
,
auto
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
import
torch
if
TYPE_CHECKING
:
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
class
AttentionType
(
Enum
):
DECODER
=
auto
()
# Decoder attention between previous layer Q/K/V
...
...
@@ -35,6 +39,16 @@ class AttentionBackend(ABC):
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
@
staticmethod
@
abstractmethod
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
raise
NotImplementedError
@
classmethod
def
make_metadata_builder
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadataBuilder"
:
return
cls
.
get_builder_cls
()(
*
args
,
**
kwargs
)
@
staticmethod
@
abstractmethod
def
get_kv_cache_shape
(
...
...
@@ -110,6 +124,33 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
class
AttentionMetadataBuilder
(
ABC
,
Generic
[
T
]):
"""Abstract class for attention metadata builders."""
@
abstractmethod
def
__init__
(
self
,
input_builder
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
add_seq_group
(
self
,
seq_group_metadata
:
"SequenceGroupMetadata"
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise
NotImplementedError
@
abstractmethod
def
build
(
self
,
runner
:
"ModelRunnerInputBuilderBase"
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
"""Build attention metadata with on-device tensors."""
raise
NotImplementedError
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
@
abstractmethod
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
2fa4623d
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
...
@@ -93,6 +94,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
BlocksparseFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"BlocksparseFlashAttentionMetadataBuilder"
]:
return
BlocksparseFlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -244,6 +249,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
return
self
.
_cached_decode_metadata
class
BlocksparseFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
BlocksparseFlashAttentionMetadata
]):
_metadata_cls
=
BlocksparseFlashAttentionMetadata
class
BlocksparseFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
vllm/attention/backends/flash_attn.py
View file @
2fa4623d
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
class
FlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -28,6 +39,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -184,6 +199,170 @@ class FlashAttentionMetadata(AttentionMetadata):
return
self
.
_cached_decode_metadata
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
seq_group_metadata
.
is_prompt
block_tables
=
seq_group_metadata
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
curr_seq_lens
,
query_lens
,
context_lens
,
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
,
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
seq_group_metadata
.
block_tables
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
,
query_lens
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors."""
device
=
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
else
:
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
self
.
block_tables
)
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
class
FlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
vllm/attention/backends/flashinfer.py
View file @
2fa4623d
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
...
...
@@ -14,7 +14,18 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -31,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashInferMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashInferMetadataBuilder"
]:
return
FlashInferMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -188,6 +203,225 @@ class FlashInferMetadata(AttentionMetadata):
return
self
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
self
.
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
self
.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
seq_group_metadata
.
is_prompt
block_tables
=
seq_group_metadata
.
block_tables
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
curr_seq_lens
,
query_lens
,
context_lens
,
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
is_profile_run
=
is_block_tables_empty
(
block_tables
)
# Compute slot mapping.
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
,
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
seq_group_metadata
.
block_tables
)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if
is_profile_run
:
return
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
block_table
=
block_tables
[
seq_id
]
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
,
query_lens
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
device
=
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
cuda_graph_pad_size
)
self
.
paged_kv_last_page_len
.
extend
([
0
]
*
cuda_graph_pad_size
)
else
:
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
self
.
block_tables
)
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
self
.
paged_kv_indptr
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
paged_kv_last_page_len_tensor
=
torch
.
tensor
(
self
.
paged_kv_last_page_len
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
runner
.
kv_cache_dtype
,
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
runner
.
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
),
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
head_dim
=
runner
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
,
logits_soft_cap
=
logits_soft_cap
)
class
FlashInferImpl
(
AttentionImpl
):
def
__init__
(
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
2fa4623d
...
...
@@ -7,6 +7,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
...
...
@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
ROCmFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
return
ROCmFlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return
self
.
_cached_decode_metadata
class
ROCmFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
ROCmFlashAttentionMetadata
]):
_metadata_cls
=
ROCmFlashAttentionMetadata
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
Optional
[
List
[
int
]],
...
...
vllm/attention/backends/utils.py
View file @
2fa4623d
"""Attention backend utils"""
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Type
,
TypeVar
,
Union
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
=
(
"ROCm/HIP is not currently supported "
"with encoder/decoder models."
)
PAD_SLOT_ID
=
-
1
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
def
is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if
block_tables
is
None
:
return
True
if
isinstance
(
block_tables
,
dict
)
and
all
(
value
is
None
for
value
in
block_tables
.
values
()):
return
True
return
False
def
compute_slot_mapping_start_idx
(
is_prompt
:
bool
,
query_len
:
int
,
context_len
:
int
,
sliding_window
:
int
,
use_v2_block_manager
:
bool
):
"""
Compute the start index of slot mapping.
"""
start_idx
=
0
if
is_prompt
and
sliding_window
is
not
None
:
assert
use_v2_block_manager
or
context_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager"
)
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx
=
max
(
0
,
query_len
-
sliding_window
)
return
start_idx
def
compute_slot_mapping
(
is_profile_run
:
bool
,
slot_mapping
:
List
[
int
],
seq_id
:
int
,
seq_len
:
int
,
context_len
:
int
,
start_idx
:
int
,
block_size
:
int
,
block_tables
:
Dict
[
int
,
List
[
int
]]):
"""
Compute slot mapping.
"""
if
is_profile_run
:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
seq_len
)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table
=
block_tables
[
seq_id
]
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
max
(
0
,
start_idx
-
context_len
))
for
i
in
range
(
max
(
start_idx
,
context_len
),
seq_len
):
block_number
=
block_table
[
i
//
block_size
]
block_offset
=
i
%
block_size
slot
=
block_number
*
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
TAttentionMetadata
=
TypeVar
(
"TAttentionMetadata"
,
bound
=
'AttentionMetadata'
)
class
CommonMetadataBuilder
(
AttentionMetadataBuilder
[
TAttentionMetadata
]):
_metadata_cls
:
Type
[
TAttentionMetadata
]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
,
chunked_prefill_enabled
):
is_prompt
=
seq_group_metadata
.
is_prompt
block_tables
=
seq_group_metadata
.
block_tables
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
curr_seq_lens
,
query_lens
,
context_lens
,
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
,
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
seq_group_metadata
.
block_tables
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
device
=
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
+
cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
else
:
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
self
.
block_tables
)
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
vllm/attention/backends/xformers.py
View file @
2fa4623d
...
...
@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
...
...
@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
XFormersMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
return
XFormersMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args(
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
XFormersMetadataBuilder
(
CommonMetadataBuilder
[
XFormersMetadata
]):
_metadata_cls
=
XFormersMetadata
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
...
...
vllm/attention/selector.py
View file @
2fa4623d
...
...
@@ -7,6 +7,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_openvino
,
is_tpu
,
is_xpu
logger
=
init_logger
(
__name__
)
...
...
@@ -136,7 +137,7 @@ def which_attn_to_use(
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
...
...
@@ -145,7 +146,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
if
current_platform
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
...
...
vllm/worker/model_runner.py
View file @
2fa4623d
...
...
@@ -2,6 +2,7 @@ import dataclasses
import
gc
import
time
import
warnings
import
weakref
from
collections
import
defaultdict
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
...
...
@@ -48,9 +49,9 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
...
...
@@ -165,6 +166,298 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
return
cls
(
**
tensor_dict
)
class
ModelInputForGPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForGPU
]):
"""TBA"""
def
__init__
(
self
,
runner
:
"GPUModelRunnerBase"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
()
self
.
runner
=
runner
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
self
.
scheduler_config
=
self
.
runner
.
scheduler_config
self
.
sliding_window
=
self
.
runner
.
sliding_window
self
.
block_size
=
self
.
runner
.
block_size
self
.
enable_lora
=
self
.
runner
.
lora_config
is
not
None
self
.
enable_prompt_adapter
=
(
self
.
runner
.
prompt_adapter_config
is
not
None
)
self
.
multi_modal_input_mapper
=
self
.
runner
.
multi_modal_input_mapper
self
.
finished_requests_ids
=
finished_requests_ids
self
.
decode_only
=
True
# Common inputs.
self
.
input_tokens
:
List
[
int
]
=
[]
self
.
input_positions
:
List
[
int
]
=
[]
self
.
seq_lens
:
List
[
int
]
=
[]
self
.
query_lens
:
List
[
int
]
=
[]
self
.
max_decode_seq_len
:
int
=
0
self
.
request_ids_to_seq_ids
:
Dict
[
str
,
List
[
int
]]
=
defaultdict
(
list
)
# LoRA inputs.
self
.
lora_index_mapping
:
List
[
int
]
=
[]
self
.
lora_prompt_mapping
:
List
[
int
]
=
[]
self
.
lora_requests
:
Set
[
LoRARequest
]
=
set
()
# Prompt adapter inputs.
self
.
prompt_adapter_index_mapping
:
List
[
int
]
=
[]
self
.
prompt_adapter_prompt_mapping
:
List
[
int
]
=
[]
self
.
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
# Multi-modal inputs.
self
.
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
# Attention metadata inputs.
self
.
attn_metadata_builder
=
self
.
attn_backend
.
make_metadata_builder
(
self
)
# Engine/Model configurations.
self
.
chunked_prefill_enabled
=
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
)
if
self
.
sliding_window
is
not
None
:
self
.
sliding_window_blocks
=
(
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
def
_compute_len_for_sliding_window
(
self
,
seq_len
:
int
):
curr_sliding_window_blocks
=
0
sliding_seq_len
=
seq_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if
self
.
sliding_window
is
not
None
:
curr_sliding_window_blocks
=
self
.
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
suff_len
=
seq_len
%
self
.
block_size
sliding_seq_len
=
min
(
seq_len
,
self
.
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
curr_sliding_window_blocks
+=
1
else
:
sliding_seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
return
curr_sliding_window_blocks
,
sliding_seq_len
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
n_seqs
=
len
(
seq_ids
)
is_prompt
=
seq_group_metadata
.
is_prompt
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
if
is_prompt
:
assert
n_seqs
==
1
self
.
decode_only
=
False
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
self
.
request_ids_to_seq_ids
[
seq_group_metadata
.
request_id
]
=
[]
# The number of input tokens in each sequence.
token_lens
:
List
[
int
]
=
[]
# The number of tokens that are already computed.
context_lens
:
List
[
int
]
=
[]
# The current sliding window block for each sequence.
curr_sliding_window_blocks
:
List
[
int
]
=
[]
# The original sequence length (before applying sliding window)
# for each sequence.
orig_seq_lens
:
List
[
int
]
=
[]
# The sequence length (may be capped to the sliding window).
curr_seq_lens
:
List
[
int
]
=
[]
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
self
.
request_ids_to_seq_ids
[
seq_group_metadata
.
request_id
].
append
(
seq_id
)
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
# Check if hit prefix cache (i.e., some blocks are already computed)
# Note that prefix caching does not support sliding window.
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
is_prompt
)
if
self
.
chunked_prefill_enabled
and
prefix_cache_hit
:
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching now."
)
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len
=
seq_data
.
get_len
()
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# Compute tokens.
if
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
context_len
:]
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
if
is_prompt
:
curr_sliding_window_block
=
0
sliding_seq_len
=
seq_len
query_len
=
seq_len
-
context_len
else
:
curr_sliding_window_block
,
sliding_seq_len
=
(
self
.
_compute_len_for_sliding_window
(
seq_len
))
query_len
=
1
self
.
seq_lens
.
append
(
sliding_seq_len
)
if
not
is_prompt
:
self
.
max_decode_seq_len
=
max
(
self
.
max_decode_seq_len
,
sliding_seq_len
)
self
.
query_lens
.
append
(
query_len
)
self
.
input_tokens
.
extend
(
tokens
)
self
.
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
# Intermediate data of the current sequence group for
# the attention metadata.
token_lens
.
append
(
len
(
tokens
))
context_lens
.
append
(
context_len
)
curr_seq_lens
.
append
(
sliding_seq_len
)
curr_sliding_window_blocks
.
append
(
curr_sliding_window_block
)
orig_seq_lens
.
append
(
seq_len
)
# Update attention metadata. Note that input builder attributes
# (self.xxx) include all added sequences, so we need to slice
# the last n_seqs sequences.
self
.
attn_metadata_builder
.
add_seq_group
(
seq_group_metadata
,
token_lens
,
orig_seq_lens
,
curr_seq_lens
,
self
.
query_lens
[
-
n_seqs
:],
context_lens
,
curr_sliding_window_blocks
,
prefix_cache_hit
,
self
.
chunked_prefill_enabled
)
# LoRA data.
if
self
.
enable_lora
:
lora_id
=
seq_group_metadata
.
lora_int_id
for
query_len
in
self
.
query_lens
[
-
n_seqs
:]:
if
lora_id
>
0
:
self
.
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
self
.
lora_index_mapping
+=
[
lora_id
]
*
query_len
self
.
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
))
# Prompt adapter data. Note that when is_prompt=True,
# we expect only one sequence in the group.
if
self
.
enable_prompt_adapter
:
prompt_adapter_id
=
seq_group_metadata
.
prompt_adapter_id
if
prompt_adapter_id
>
0
and
is_prompt
:
query_len
=
self
.
query_lens
[
-
1
]
self
.
prompt_adapter_requests
.
add
(
seq_group_metadata
.
prompt_adapter_request
)
num_tokens
=
seq_group_metadata
.
\
prompt_adapter_num_virtual_tokens
pm
=
[
prompt_adapter_id
]
*
num_tokens
+
[
0
]
*
(
query_len
-
num_tokens
)
self
.
prompt_adapter_index_mapping
+=
pm
self
.
prompt_adapter_prompt_mapping
.
extend
(
[
prompt_adapter_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
# Multi-modal data.
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
self
.
multi_modal_inputs_list
.
append
(
mm_kwargs
)
def
build
(
self
)
->
ModelInputForGPU
:
if
not
self
.
input_tokens
:
return
self
.
model_input_cls
()
batch_size
=
len
(
self
.
input_tokens
)
use_captured_graph
=
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
self
.
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size
=
-
1
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
cuda_graph_pad_size
=
graph_batch_size
-
batch_size
batch_size
=
graph_batch_size
# Tokens and positions.
self
.
input_tokens
.
extend
([
0
]
*
cuda_graph_pad_size
)
self
.
input_positions
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_tokens_tensor
=
torch
.
tensor
(
self
.
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
input_positions_tensor
=
torch
.
tensor
(
self
.
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
# Sequence and query lengths.
self
.
seq_lens
.
extend
([
1
]
*
cuda_graph_pad_size
)
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
self
.
runner
,
self
.
seq_lens
,
self
.
query_lens
,
cuda_graph_pad_size
,
batch_size
)
# LoRA data.
if
self
.
enable_lora
:
self
.
lora_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
lora_mapping
=
LoRAMapping
(
self
.
lora_index_mapping
,
self
.
lora_prompt_mapping
,
)
else
:
lora_mapping
=
None
# Prompt adapter data.
if
self
.
enable_prompt_adapter
:
self
.
prompt_adapter_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
prompt_adapter_mapping
=
PromptAdapterMapping
(
self
.
prompt_adapter_index_mapping
,
self
.
prompt_adapter_prompt_mapping
,
)
else
:
prompt_adapter_mapping
=
None
# Multi-modal data.
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
self
.
multi_modal_inputs_list
,
device
=
self
.
runner
.
device
)
return
self
.
model_input_cls
(
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
attn_metadata
=
attn_metadata
,
seq_lens
=
self
.
seq_lens
,
query_lens
=
self
.
query_lens
,
lora_mapping
=
lora_mapping
,
lora_requests
=
self
.
lora_requests
,
multi_modal_kwargs
=
multi_modal_kwargs
,
request_ids_to_seq_ids
=
self
.
request_ids_to_seq_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_requests
=
self
.
prompt_adapter_requests
)
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
"""
Helper class for shared methods between GPU model runners.
...
...
@@ -368,464 +661,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
prompt_adapter_index_mapping
:
List
[
int
]
=
[]
prompt_adapter_prompt_mapping
:
List
[
int
]
=
[]
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
seq_lens
:
List
[
int
]
=
[]
prefill_seq_lens
:
List
[
int
]
=
[]
decode_seq_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
request_ids_to_seq_ids
:
Dict
[
str
,
List
[
int
]]
=
defaultdict
(
list
)
decode_only
=
True
num_prefills
=
0
num_prefill_tokens
=
0
num_decode_tokens
=
0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len
:
List
[
int
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_model_input_cls
()
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
block_aligned_sliding_window
=
\
sliding_window_blocks
*
self
.
block_size
builder
=
ModelInputForGPUBuilder
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_id
in
seq_ids
:
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
not
(
computed_block_nums
is
None
or
computed_block_nums
==
[])):
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"now."
)
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_data
.
get_len
()
-
1
seq_len
=
min
(
seq_data
.
get_len
(),
context_len
+
seq_group_metadata
.
token_chunk_size
)
if
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
is_prompt
)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks
=
None
sliding_seq_len
=
seq_len
sliding_context_len
=
context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if
(
self
.
sliding_window
is
not
None
and
not
is_prompt
):
curr_sliding_window_blocks
=
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
suff_len
=
seq_len
%
self
.
block_size
sliding_seq_len
=
min
(
seq_len
,
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
curr_sliding_window_blocks
+=
1
else
:
sliding_seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
sliding_context_len
=
sliding_seq_len
-
1
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
context_len
:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert
self
.
sliding_window
is
None
,
\
"Prefix caching is not supported with sliding window"
sliding_context_len
=
context_len
if
self
.
attn_backend
.
get_name
()
==
"flash-attn"
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
else
:
block_table
=
computed_block_nums
elif
(
self
.
scheduler_config
.
chunked_prefill_enabled
or
not
is_prompt
):
if
seq_group_metadata
.
block_tables
is
not
None
:
# chunked prefill or decode
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
if
curr_sliding_window_blocks
is
not
None
:
block_table
=
block_table
[
-
curr_sliding_window_blocks
:]
else
:
# Only happens when memory profiling runs.
block_table
=
[]
else
:
# Prefill without chunked prefill or memory profiling.
block_table
=
[]
block_tables
.
append
(
block_table
)
seq_lens
.
append
(
sliding_seq_len
)
context_lens
.
append
(
sliding_context_len
)
query_len
=
sliding_seq_len
-
sliding_context_len
query_lens
.
append
(
query_len
)
input_tokens
.
extend
(
tokens
)
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
lora_id
=
seq_group_metadata
.
lora_int_id
prompt_adapter_id
=
seq_group_metadata
.
prompt_adapter_id
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
num_prefills
+=
1
num_prefill_tokens
+=
len
(
tokens
)
decode_only
=
False
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
num_decode_tokens
+=
query_len
decode_seq_lens
.
append
(
sliding_seq_len
)
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
query_len
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
))
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
# Process multi-modal data
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
if
prompt_adapter_id
>
0
and
is_prompt
:
prompt_adapter_requests
.
add
(
seq_group_metadata
.
prompt_adapter_request
)
num_tokens
=
seq_group_metadata
.
\
prompt_adapter_num_virtual_tokens
pm
=
[
prompt_adapter_id
]
*
num_tokens
+
[
0
]
*
(
query_len
-
num_tokens
)
prompt_adapter_index_mapping
+=
pm
prompt_adapter_prompt_mapping
.
extend
(
[
prompt_adapter_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
is_profile_run
=
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
)
if
is_profile_run
:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
if
is_prompt
:
assert
self
.
scheduler_config
.
use_v2_block_manager
\
or
context_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager"
)
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx
=
max
(
0
,
query_len
-
self
.
sliding_window
)
for
i
in
range
(
context_len
,
seq_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
# Prepare input tensors for flashinfer
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
seq_len
=
seq_data
.
get_len
()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
paged_kv_last_page_len
.
append
(
last_page_len
)
batch_size
=
len
(
input_tokens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
decode_seq_lens
,
default
=
0
)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph
=
(
decode_only
and
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
max_seq_len_to_capture
)
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
for
_
in
range
(
graph_batch_size
-
batch_size
):
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
seq_lens
.
append
(
1
)
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
prompt_adapter_index_mapping
.
append
(
0
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
last_paged_kv_indptr
=
paged_kv_indptr
[
-
1
]
paged_kv_indptr
.
append
(
last_paged_kv_indptr
)
paged_kv_last_page_len
.
append
(
0
)
batch_size
=
graph_batch_size
num_decode_tokens
=
batch_size
if
use_captured_graph
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
self
.
device
)
else
:
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions_tensor
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping_tensor
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
logits_soft_cap
=
getattr
(
self
.
model_config
.
hf_config
,
'attn_logit_softcapping'
,
None
)
if
logits_soft_cap
is
not
None
and
self
.
attn_backend
.
get_name
(
)
!=
"flashinfer"
:
raise
ValueError
(
"Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
len
(
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
paged_kv_indices
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
paged_kv_indptr
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
paged_kv_last_page_len_tensor
=
torch
.
tensor
(
paged_kv_last_page_len
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
head_dim
=
self
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
self
.
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
,
logits_soft_cap
=
logits_soft_cap
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_prompt_mapping
,
)
else
:
lora_mapping
=
None
if
self
.
prompt_adapter_config
:
prompt_adapter_mapping
=
PromptAdapterMapping
(
prompt_adapter_index_mapping
,
prompt_adapter_prompt_mapping
,
)
else
:
prompt_adapter_mapping
=
None
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
request_ids_to_seq_ids
=
{
seq_group_metadata
.
request_id
:
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_group_metadata
in
seq_group_metadata_list
}
return
self
.
_model_input_cls
(
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
attn_metadata
=
attn_metadata
,
seq_lens
=
seq_lens
,
query_lens
=
query_lens
,
lora_mapping
=
lora_mapping
,
lora_requests
=
lora_requests
,
multi_modal_kwargs
=
multi_modal_kwargs
,
request_ids_to_seq_ids
=
request_ids_to_seq_ids
,
finished_requests_ids
=
finished_requests_ids
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_requests
=
prompt_adapter_requests
,
)
builder
.
add_seq_group
(
seq_group_metadata
)
return
builder
.
build
()
# type: ignore
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
...
...
vllm/worker/model_runner_base.py
View file @
2fa4623d
...
...
@@ -113,6 +113,21 @@ class ModelRunnerInputBase(ABC):
raise
NotImplementedError
class
ModelRunnerInputBuilderBase
(
ABC
,
Generic
[
T
]):
"""A builder to create ModelRunnerInputBase objects.
"""
@
abstractmethod
def
add_seq_group
(
self
,
seq_group_metadata
):
"""TBA"""
raise
NotImplementedError
@
abstractmethod
def
build
(
self
,
*
args
,
**
kwargs
)
->
T
:
"""Build metadata with on-device tensors."""
raise
NotImplementedError
class
ModelRunnerBase
(
ABC
,
Generic
[
T
]):
"""
Model runner interface that abstracts a particular hardware and/or type of
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment