Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
705f6a35
Commit
705f6a35
authored
Jul 16, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1
parents
af837396
4cf256ae
Changes
439
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2065 additions
and
760 deletions
+2065
-760
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+27
-12
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+7
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+398
-82
vllm/attention/layer.py
vllm/attention/layer.py
+24
-12
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+3
-2
vllm/attention/ops/blocksparse_attention/utils.py
vllm/attention/ops/blocksparse_attention/utils.py
+33
-7
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+120
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+11
-3
vllm/attention/selector.py
vllm/attention/selector.py
+35
-4
vllm/block.py
vllm/block.py
+0
-43
vllm/config.py
vllm/config.py
+306
-133
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+57
-30
vllm/core/block/common.py
vllm/core/block/common.py
+154
-44
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+56
-28
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+44
-12
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+150
-68
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+476
-208
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+12
-10
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+104
-49
vllm/core/scheduler.py
vllm/core/scheduler.py
+48
-13
No files found.
Too many changes to show.
To preserve performance only
439 of 439+
files are displayed.
Plain diff
Email patch
vllm/attention/backends/torch_sdpa.py
View file @
705f6a35
...
@@ -7,9 +7,17 @@ import torch
...
@@ -7,9 +7,17 @@ import torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
PagedAttentionMetadata
)
from
vllm.utils
import
is_cpu
if
is_cpu
():
try
:
from
vllm.attention.ops.ipex_attn
import
PagedAttention
except
ImportError
:
from
vllm.attention.ops.paged_attn
import
PagedAttention
else
:
from
vllm.attention.ops.paged_attn
import
PagedAttention
class
TorchSDPABackend
(
AttentionBackend
):
class
TorchSDPABackend
(
AttentionBackend
):
...
@@ -23,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -23,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
return
TorchSDPABackendImpl
return
TorchSDPABackendImpl
@
staticmethod
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"TorchSDPA
Metadata"
:
def
get
_metadata
_cls
()
->
Type
[
"Attention
Metadata"
]
:
return
TorchSDPAMetadata
(
*
args
,
**
kwargs
)
return
TorchSDPAMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -137,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -137,6 +145,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass with torch SDPA and PagedAttention.
...
@@ -150,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -150,6 +159,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
kv_scale
==
1.0
assert
kv_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
@@ -197,13 +211,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -197,13 +211,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata
.
attn_bias
):
attn_metadata
.
attn_bias
):
end
=
start
+
seq_len
end
=
start
+
seq_len
sub_out
=
scaled_dot_product_attention
(
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
query
[
None
,
:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
key
[
None
,
:,
start
:
end
,
:],
value
[:,
start
:
end
,
:],
value
[
None
,
:,
start
:
end
,
:],
attn_mask
=
mask
,
attn_mask
=
mask
,
dropout_p
=
0.0
,
dropout_p
=
0.0
,
is_causal
=
not
self
.
need_mask
,
is_causal
=
not
self
.
need_mask
,
scale
=
self
.
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
scale
=
self
.
scale
).
squeeze
(
0
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
start
=
end
else
:
else
:
...
@@ -236,7 +251,7 @@ def _make_alibi_bias(
...
@@ -236,7 +251,7 @@ def _make_alibi_bias(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
...
@@ -248,7 +263,7 @@ def _make_alibi_bias(
...
@@ -248,7 +263,7 @@ def _make_alibi_bias(
num_heads
=
alibi_slopes
.
shape
[
0
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
.
unsqueeze_
(
0
)
inf_mask
=
torch
.
empty
(
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
...
@@ -262,7 +277,7 @@ def _make_sliding_window_bias(
...
@@ -262,7 +277,7 @@ def _make_sliding_window_bias(
window_size
:
Optional
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
(
1
,
seq_len
,
seq_len
),
...
...
vllm/attention/backends/utils.py
0 → 100644
View file @
705f6a35
"""Attention backend utils"""
# Error string(s) for encoder/decoder
# unsupported attention scenarios.
#
# Kept as a module-level constant so that every place reporting the missing
# ROCm/HIP encoder/decoder support can emit the exact same message.
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")
vllm/attention/backends/xformers.py
View file @
705f6a35
...
@@ -6,10 +6,11 @@ import torch
...
@@ -6,10 +6,11 @@ import torch
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
AttentionBias
,
from
xformers.ops.fmha.attn_bias
import
(
AttentionBias
,
BlockDiagonalCausalMask
,
BlockDiagonalCausalMask
,
BlockDiagonalMask
,
LowerTriangularMaskWithTensorBias
)
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -28,8 +29,8 @@ class XFormersBackend(AttentionBackend):
...
@@ -28,8 +29,8 @@ class XFormersBackend(AttentionBackend):
return
XFormersImpl
return
XFormersImpl
@
staticmethod
@
staticmethod
def
make
_metadata
(
*
args
,
**
kwargs
)
->
"XFormers
Metadata"
:
def
get
_metadata
_cls
()
->
Type
[
"Attention
Metadata"
]
:
return
XFormersMetadata
(
*
args
,
**
kwargs
)
return
XFormersMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -66,11 +67,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
updated from `CUDAGraphRunner.forward` API.
"""
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |---------------- N iteration ---------------------|
...
@@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -79,8 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
# seq_lens stored as a tensor.
max_query_len
:
Optional
[
int
]
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# FIXME: It is for flash attn.
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
# requests only.
...
@@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -88,26 +85,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum sequence length among decode batch. 0 if there are prefill
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
# requests only.
max_decode_seq_len
:
int
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# Whether or not cuda graph is enabled.
# is [4, 6], it is [0, 4, 10].
# Cuda-graph is currently enabled for decoding only.
query_start_loc
:
Optional
[
torch
.
Tensor
]
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is decoding.
seq_lens
:
Optional
[
List
[
int
]]
=
None
# FIXME: It is for flash attn.
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,) A tensor of context lengths (tokens that are computed
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Whether or not if cuda graph is enabled.
# Maximum query length in the batch. None for decoding.
# Cuda-graph is currently enabled for decoding only.
max_query_len
:
Optional
[
int
]
=
None
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
_cached_prefill_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# It is a list because it is needed to set per prompt
...
@@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -115,6 +141,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# from xformer API.
# will not appear in the __repr__ and __init__
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
cross_attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
@property
def is_all_encoder_attn_metadata_set(self) -> bool:
    '''
    All attention metadata required for encoder attention is set.

    Returns:

    * True only when encoder_seq_lens, encoder_seq_lens_tensor and
      max_encoder_seq_len are all non-None; False otherwise
    '''
    # Each field is compared against None explicitly (rather than relying on
    # truthiness) so that an empty list or a zero length still counts as set.
    return ((self.encoder_seq_lens is not None)
            and (self.encoder_seq_lens_tensor is not None)
            and (self.max_encoder_seq_len is not None))
@property
def is_all_cross_attn_metadata_set(self) -> bool:
    '''
    All attention metadata required for enc/dec cross-attention is set.

    Superset of encoder attention required metadata.

    Returns:

    * True only when the encoder attention metadata is fully set and
      both cross_slot_mapping and cross_block_tables are non-None
    '''
    # The encoder check comes first so the cross-attention fields are only
    # inspected once the encoder-side requirements are already satisfied.
    return (self.is_all_encoder_attn_metadata_set
            and (self.cross_slot_mapping is not None)
            and (self.cross_block_tables is not None))
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
...
@@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -122,30 +170,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return
None
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
if
self
.
_cached_prefill_metadata
is
not
None
:
# Recover cached prefill-phase attention
# metadata structure
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
((
self
.
seq_lens
is
not
None
)
assert
self
.
seq_lens_tensor
is
not
None
or
(
self
.
encoder_seq_lens
is
not
None
))
assert
self
.
query_start_loc
is
not
None
assert
((
self
.
seq_lens_tensor
is
not
None
)
assert
self
.
context_lens_tensor
is
not
None
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
assert
self
.
block_tables
is
not
None
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
# Construct & cache prefill-phase attention metadata structure
self
.
_cached_prefill_metadata
=
XFormersMetadata
(
self
.
_cached_prefill_metadata
=
XFormersMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
]
,
slot_mapping
=
slot_mapping
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
]
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
]
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
query_start_loc
=
query_start_loc
,
seq_start_loc
=
None
,
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
block_tables
,
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
@
property
@
property
...
@@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -154,29 +222,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
return
None
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
if
self
.
_cached_decode_metadata
is
not
None
:
# Recover cached decode-phase attention
# metadata structure
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
((
self
.
seq_lens_tensor
is
not
None
)
assert
self
.
seq_lens_tensor
is
not
None
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
# Construct & cache decode-phase attention metadata structure
self
.
_cached_decode_metadata
=
XFormersMetadata
(
self
.
_cached_decode_metadata
=
XFormersMetadata
(
num_prefills
=
0
,
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
slot_mapping
=
slot_mapping
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
block_tables
=
block_tables
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: AttentionType,
) -> Optional[List[Optional[AttentionBias]]]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate attention bias value given the attention type
      (None when that bias has not been set yet)

    Raises:

    * AttributeError: attn_type is not a recognized attention type
    '''

    if attn_type == AttentionType.DECODER:
        return attn_metadata.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    else:
        # Mirror _set_attn_bias: an unknown attention type is an error
        # rather than being silently treated as cross-attention.
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: AttentionType,
) -> None:
    '''
    Update appropriate attention bias field of attention metadata,
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:

    * AttributeError: attn_type is not a recognized attention type
    '''

    # Each attention type owns a separate bias field on the metadata, so
    # setting one never overwrites the bias stored for another type.
    if attn_type == AttentionType.DECODER:
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError: attn_type is not a recognized attention type
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        max_seq_len = (attn_metadata.max_prefill_seq_len
                       if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
...
@@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -238,51 +423,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
,
value
:
torch
.
Tensor
,
value
:
Optional
[
torch
.
Tensor
]
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
"XFormersMetadata"
,
attn_metadata
:
"XFormersMetadata"
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
Args:
query: shape = [num_tokens, num_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
is
not
None
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
# Reshape the input keys and values and store them in the cache.
if
(
key
is
not
None
)
and
(
value
is
not
None
):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
# Update cross-attention KV cache (prefill-only)
value_cache
,
# During cross-attention decode, key & value will be None,
attn_metadata
.
slot_mapping
,
# preventing this IF-statement branch from running
self
.
kv_cache_dtype
,
kv_scale
)
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
# Update self-attention KV cache (prefill/decode)
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
updated_slot_mapping
=
attn_metadata
.
slot_mapping
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
if
key
is
not
None
and
value
is
not
None
:
value
=
value
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
...
@@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -294,10 +572,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# block tables are empty if the prompt does not have a cached
# block tables are empty if the prompt does not have a cached
# prefix.
# prefix.
out
=
self
.
_run_memory_efficient_xformers_forward
(
out
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
prefill_meta
)
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
assert
prefill_meta
.
query_start_loc
is
not
None
assert
prefill_meta
.
max_query_len
is
not
None
# prefix-enabled attention
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# deal with different data types between KV and FP8 KV cache,
...
@@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -320,13 +602,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
_get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
decode_query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
block_tables
_arg
,
decode_meta
.
seq_lens_
tensor
,
seq_lens_
arg
,
decode_meta
.
max_decode
_seq_len
,
max
_seq_len
_arg
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -343,6 +632,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
XFormersMetadata
,
attn_metadata
:
XFormersMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Attention for 1D query of multiple prompts. Multiple prompt
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
tokens are flattened in to `query` input.
...
@@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -356,8 +646,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
"""
"""
assert
attn_metadata
.
seq_lens
is
not
None
original_query
=
query
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
# GQA/MQA requires the shape [B, M, G, H, K].
...
@@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -375,18 +669,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Set attention bias if not provided. This typically happens at
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
# FIXME(woosuk): This is a hack.
if
attn_metadata
.
attn_bias
is
None
:
attn_bias
=
_get_attn_bias
(
attn_metadata
,
attn_type
)
if
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
if
(
attn_type
==
AttentionType
.
ENCODER_DECODER
):
attn_metadata
.
seq_lens
)
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
encoder_seq_lens
is
not
None
# Default enc/dec cross-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
,
attn_metadata
.
encoder_seq_lens
)
elif
attn_type
==
AttentionType
.
ENCODER
:
assert
attn_metadata
.
encoder_seq_lens
is
not
None
# Default encoder self-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
encoder_seq_lens
)
else
:
assert
attn_metadata
.
seq_lens
is
not
None
# Default decoder self-attention mask is causal
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
)
if
self
.
sliding_window
is
not
None
:
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
self
.
sliding_window
)
attn_metadata
.
attn_bias
=
[
attn_bias
]
attn_bias
=
[
attn_bias
]
else
:
else
:
attn_metadata
.
attn_bias
=
_make_alibi_bias
(
assert
attn_metadata
.
seq_lens
is
not
None
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
attn_metadata
.
seq_lens
)
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
_set_attn_bias
(
attn_metadata
,
attn_bias
,
attn_type
)
# No alibi slopes.
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# TODO(woosuk): Too many view operations. Let's try to reduce
...
@@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -400,7 +715,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query
,
query
,
key
,
key
,
value
,
value
,
attn_bias
=
attn_
metadata
.
attn_
bias
[
0
],
attn_bias
=
attn_bias
[
0
],
p
=
0.0
,
p
=
0.0
,
scale
=
self
.
scale
)
scale
=
self
.
scale
)
return
out
.
view_as
(
original_query
)
return
out
.
view_as
(
original_query
)
...
@@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -409,6 +724,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# FIXME(woosuk): Because xformers does not support dynamic sequence
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# one. This is inefficient, especially when we have many short prompts.
assert
attn_metadata
.
seq_lens
is
not
None
output
=
torch
.
empty_like
(
original_query
)
output
=
torch
.
empty_like
(
original_query
)
start
=
0
start
=
0
for
i
,
seq_len
in
enumerate
(
attn_metadata
.
seq_lens
):
for
i
,
seq_len
in
enumerate
(
attn_metadata
.
seq_lens
):
...
@@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -417,7 +733,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query
[
None
,
start
:
end
],
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
attn_
metadata
.
attn_
bias
[
i
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
p
=
0.0
,
scale
=
self
.
scale
)
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
...
@@ -431,8 +747,8 @@ def _make_alibi_bias(
...
@@ -431,8 +747,8 @@ def _make_alibi_bias(
num_kv_heads
:
int
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
)
->
L
owerTriangularMaskWithTensor
Bias
:
)
->
L
ist
[
Attention
Bias
]
:
attn_biases
=
[]
attn_biases
:
List
[
AttentionBias
]
=
[]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
...
...
vllm/attention/layer.py
View file @
705f6a35
...
@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional
...
@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8KVCacheMethod
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
...
@@ -56,15 +57,19 @@ class Attention(nn.Module):
...
@@ -56,15 +57,19 @@ class Attention(nn.Module):
quant_method
=
quant_config
.
get_quant_method
(
quant_method
=
quant_config
.
get_quant_method
(
self
)
if
quant_config
else
None
self
)
if
quant_config
else
None
if
quant_method
is
not
None
:
if
quant_method
is
not
None
:
if
self
.
kv_cache_dtype
==
"fp8_e5m2"
:
assert
isinstance
(
quant_method
,
Fp8KVCacheMethod
)
raise
ValueError
(
"fp8_e5m2 kv-cache is not supported with "
# TODO (mgoin): kv cache dtype should be specified in the FP8
"fp8 checkpoints."
)
# checkpoint config and become the "auto" behavior
# When FP8 quantization is enabled, we make a parameter
if
"fp8"
in
self
.
kv_cache_dtype
:
# "kv_scale" so that it can be loaded from FP8 checkpoint.
if
self
.
kv_cache_dtype
==
"fp8_e5m2"
:
# The kv_scale will then be converted back
raise
ValueError
(
"fp8_e5m2 kv-cache is not supported with "
# to self._kv_scale in a native float32 value after weight loading.
"fp8 checkpoints."
)
self
.
quant_method
=
quant_method
# When FP8 quantization is enabled, we make a parameter
self
.
quant_method
.
create_weights
(
self
)
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back to self._kv_scale
# in a native float32 value after weight loading.
self
.
quant_method
=
quant_method
self
.
quant_method
.
create_weights
(
self
)
# During model initialization, the default dtype is set as the model
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
# weight and activation dtype.
...
@@ -85,9 +90,16 @@ class Attention(nn.Module):
...
@@ -85,9 +90,16 @@ class Attention(nn.Module):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_kv_scale
)
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_kv_scale
,
attn_type
=
attn_type
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
...
vllm/attention/ops/blocksparse_attention/interface.py
View file @
705f6a35
...
@@ -2,13 +2,14 @@ import math
...
@@ -2,13 +2,14 @@ import math
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
get_sparse_attn_mask
)
IS_COMPUTE_8_OR_ABOVE
=
(
torch
.
cuda
.
is_available
()
IS_COMPUTE_8_OR_ABOVE
=
(
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
)
and
current_platform
.
get_device_capability
()[
0
]
>=
8
)
if
IS_COMPUTE_8_OR_ABOVE
:
if
IS_COMPUTE_8_OR_ABOVE
:
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
from
.blocksparse_attention_kernel
import
blocksparse_flash_attn_varlen_fwd
...
@@ -235,4 +236,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
...
@@ -235,4 +236,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
v
,
v
,
cu_seqlens_k
,
cu_seqlens_k
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
sm_scale
=
sm_scale
)
sm_scale
=
sm_scale
)
\ No newline at end of file
vllm/attention/ops/blocksparse_attention/utils.py
View file @
705f6a35
...
@@ -4,9 +4,35 @@
...
@@ -4,9 +4,35 @@
from
functools
import
lru_cache
from
functools
import
lru_cache
import
numpy
as
np
import
torch
import
torch
import
triton
import
triton
from
scipy
import
sparse
class csr_matrix:
    """Minimal CSR (compressed sparse row) conversion of a dense 2-D array.

    Simple implementation of CSR matrix conversion without scipy.
    This replaced scipy.sparse.csr_matrix() previously used.

    After construction the instance exposes:
      * ``shape``   - the shape of the input array (rows, cols).
      * ``data``    - the non-zero values, in row-major order.
      * ``indices`` - the column index of every stored value.
      * ``indptr``  - ``rows + 1`` offsets; row ``i`` owns the slice
                      ``indices[indptr[i]:indptr[i + 1]]``.

    Raises:
        ValueError: if the input is not a NumPy array, or (from the shape
            unpacking) if it is not two-dimensional.
    """

    def __init__(self, input_array):
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")

        self.shape = input_array.shape
        # Unpacking enforces a 2-D input (ValueError otherwise).
        rows, _ = self.shape

        # np.nonzero scans in row-major (C) order, i.e. the same order the
        # element-by-element double loop would visit the entries, but it
        # runs in native code instead of rows * cols Python iterations.
        row_idx, col_idx = np.nonzero(input_array)

        # Number of stored entries per row -> cumulative row offsets.
        row_counts = np.bincount(row_idx, minlength=rows)
        indptr = np.zeros(rows + 1, dtype=col_idx.dtype)
        indptr[1:] = np.cumsum(row_counts)

        # Index arrays stay integer-typed even when nothing is non-zero
        # (building them from an empty Python list would yield float64).
        self.data = input_array[row_idx, col_idx]
        self.indices = col_idx
        self.indptr = indptr
def
dense_to_crow_col
(
x
:
torch
.
Tensor
):
def
dense_to_crow_col
(
x
:
torch
.
Tensor
):
...
@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
...
@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
assert
x
.
dim
()
in
(
2
,
3
)
assert
x
.
dim
()
in
(
2
,
3
)
if
x
.
dim
()
==
2
:
if
x
.
dim
()
==
2
:
x
=
x
[
None
]
x
=
x
[
None
]
x
=
[
sparse
.
csr_matrix
(
xi
.
bool
().
cpu
().
numpy
())
for
xi
in
x
]
x
=
[
csr_matrix
(
xi
.
bool
().
cpu
().
numpy
())
for
xi
in
x
]
crows
=
torch
.
vstack
([
torch
.
from_numpy
(
xi
.
indptr
)
for
xi
in
x
])
crows
=
torch
.
vstack
([
torch
.
from_numpy
(
xi
.
indptr
)
for
xi
in
x
])
cols
=
[
torch
.
from_numpy
(
xi
.
indices
)
for
xi
in
x
]
cols
=
[
torch
.
from_numpy
(
xi
.
indices
)
for
xi
in
x
]
max_cols
=
max
(
len
(
xi
)
for
xi
in
cols
)
max_cols
=
max
(
len
(
xi
)
for
xi
in
cols
)
...
@@ -77,11 +103,11 @@ def _get_sparse_attn_mask_homo_head(
...
@@ -77,11 +103,11 @@ def _get_sparse_attn_mask_homo_head(
):
):
"""
"""
:return: a tuple of 3:
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
- tuple of crow_indices, col_indices representation
of CSR format.
of CSR format.
- block dense mask
- block dense mask
- all token dense mask (be aware that it can be
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
OOM if it is too big) if `return_dense==True`,
otherwise, None
otherwise, None
"""
"""
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -148,10 +174,10 @@ def get_sparse_attn_mask(
...
@@ -148,10 +174,10 @@ def get_sparse_attn_mask(
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others)
or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3:
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
- tuple of crow_indices, col_indices representation
of CSR format.
of CSR format.
- block dense mask
- block dense mask
- all token dense mask (be aware that it can be OOM if it
- all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None
is too big) if `return_dense==True`, otherwise, None
"""
"""
assert
dense_mask_type
in
(
"binary"
,
"bias"
)
assert
dense_mask_type
in
(
"binary"
,
"bias"
)
...
...
vllm/attention/ops/ipex_attn.py
0 → 100644
View file @
705f6a35
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
import
torch
from
vllm
import
_custom_ops
as
ops
class PagedAttention:
    """Static helpers for paged KV-cache attention that call into IPEX.

    Every method is a ``@staticmethod``; the class only serves as a
    namespace. Extra positional arguments (``*args``) are accepted and
    ignored by all methods.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of head sizes this implementation advertises."""
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        """Return the shape ``(2, num_blocks, elements_per_block)``.

        ``elements_per_block`` is ``block_size * num_kv_heads * head_size``.
        """
        elements_per_block = block_size * num_kv_heads * head_size
        return (2, num_blocks, elements_per_block)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined cache into key and value views.

        ``kv_cache[0]`` and ``kv_cache[1]`` are each viewed as
        ``(num_blocks, num_kv_heads, -1, head_size)``, where ``num_blocks``
        is ``kv_cache.shape[1]``.
        """
        num_blocks = kv_cache.shape[1]
        per_cache_shape = (num_blocks, num_kv_heads, -1, head_size)
        raw_key, raw_value = kv_cache[0], kv_cache[1]
        return raw_key.view(*per_cache_shape), raw_value.view(*per_cache_shape)

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
        *args,
    ) -> None:
        """Store ``key``/``value`` into the caches via IPEX.

        ``kv_cache_dtype`` and ``kv_scale`` are accepted but not forwarded
        to the IPEX call.
        """
        # The slot mapping is flattened and cast to int32 before the call.
        flat_slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, flat_slots)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
        *args,
    ) -> torch.Tensor:
        """Run IPEX single-query attention over the cached keys/values.

        Returns a new tensor shaped like ``query`` that the IPEX kernel is
        handed as its output buffer. ``kv_cache_dtype`` and ``kv_scale``
        are accepted but not forwarded.
        """
        result = torch.empty_like(query)
        # Third dimension of the value cache is used as the block size.
        block_size = value_cache.shape[2]

        # Map each query head to a KV head: KV head ids repeated once per
        # query head in the group, i.e. query.size(1) // num_kv_heads times.
        heads_per_kv = query.size(1) // num_kv_heads
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        )
        head_mapping = kv_head_ids.view(num_kv_heads, 1)
        head_mapping = head_mapping.repeat_interleave(heads_per_kv).flatten()

        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            result,
            query.contiguous(),
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
        return result

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
        *args,
    ) -> torch.Tensor:
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
        *args,
    ) -> None:
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
        *args,
    ) -> None:
        """Forward a block copy to ``ops.copy_blocks``.

        Index 0 of every entry in ``kv_caches`` is passed as its key cache
        and index 1 as its value cache.
        """
        key_caches = []
        value_caches = []
        for layer_cache in kv_caches:
            key_caches.append(layer_cache[0])
            value_caches.append(layer_cache[1])
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
vllm/attention/ops/prefix_prefill.py
View file @
705f6a35
...
@@ -5,6 +5,8 @@ import torch
...
@@ -5,6 +5,8 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.platforms
import
current_platform
if
triton
.
__version__
>=
"2.1.0"
:
if
triton
.
__version__
>=
"2.1.0"
:
@
triton
.
jit
@
triton
.
jit
...
@@ -683,8 +685,14 @@ if triton.__version__ >= "2.1.0":
...
@@ -683,8 +685,14 @@ if triton.__version__ >= "2.1.0":
alibi_slopes
=
None
,
alibi_slopes
=
None
,
sliding_window
=
None
):
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
cap
=
current_platform
.
get_device_capability
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if
q
.
dtype
is
torch
.
float32
:
BLOCK
=
BLOCK
//
2
# shape constraints
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lq
==
Lk
and
Lk
==
Lv
...
@@ -716,7 +724,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -716,7 +724,7 @@ if triton.__version__ >= "2.1.0":
b_ctx_len
,
b_ctx_len
,
alibi_slopes
,
alibi_slopes
,
v_cache
.
shape
[
3
],
v_cache
.
shape
[
3
],
8
,
k_cache
.
shape
[
4
]
,
o
,
o
,
b_loc
.
stride
(
0
),
b_loc
.
stride
(
0
),
b_loc
.
stride
(
1
),
b_loc
.
stride
(
1
),
...
@@ -766,7 +774,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -766,7 +774,7 @@ if triton.__version__ >= "2.1.0":
b_seq_len
,
b_seq_len
,
b_ctx_len
,
b_ctx_len
,
v_cache
.
shape
[
3
],
v_cache
.
shape
[
3
],
8
,
k_cache
.
shape
[
4
]
,
o
,
o
,
b_loc
.
stride
(
0
),
b_loc
.
stride
(
0
),
b_loc
.
stride
(
1
),
b_loc
.
stride
(
1
),
...
...
vllm/attention/selector.py
View file @
705f6a35
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
,
is_openvino
,
is_tpu
,
is_xpu
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -17,7 +17,10 @@ class _Backend(enum.Enum):
...
@@ -17,7 +17,10 @@ class _Backend(enum.Enum):
XFORMERS
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
OPENVINO
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
...
@@ -57,15 +60,29 @@ def get_attn_backend(
...
@@ -57,15 +60,29 @@ def get_attn_backend(
ROCmFlashAttentionBackend
)
ROCmFlashAttentionBackend
)
return
ROCmFlashAttentionBackend
return
ROCmFlashAttentionBackend
elif
backend
==
_Backend
.
TORCH_SDPA
:
elif
backend
==
_Backend
.
TORCH_SDPA
:
assert
is_cpu
(),
RuntimeError
(
"Torch SDPA backend is only used for the CPU device."
)
logger
.
info
(
"Using Torch SDPA backend."
)
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
OPENVINO
:
logger
.
info
(
"Using OpenVINO Attention backend."
)
from
vllm.attention.backends.openvino
import
OpenVINOAttentionBackend
return
OpenVINOAttentionBackend
elif
backend
==
_Backend
.
IPEX
:
assert
is_xpu
(),
RuntimeError
(
"IPEX attention backend is only used for the XPU device."
)
logger
.
info
(
"Using IPEX attention backend."
)
from
vllm.attention.backends.ipex_attn
import
IpexAttnBackend
return
IpexAttnBackend
elif
backend
==
_Backend
.
FLASHINFER
:
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
return
FlashInferBackend
elif
backend
==
_Backend
.
PALLAS
:
logger
.
info
(
"Using Pallas backend."
)
from
vllm.attention.backends.pallas
import
PallasAttentionBackend
return
PallasAttentionBackend
else
:
else
:
raise
ValueError
(
"Invalid attention backend."
)
raise
ValueError
(
"Invalid attention backend."
)
...
@@ -80,7 +97,6 @@ def which_attn_to_use(
...
@@ -80,7 +97,6 @@ def which_attn_to_use(
block_size
:
int
,
block_size
:
int
,
)
->
_Backend
:
)
->
_Backend
:
"""Returns which flash attention backend to use."""
"""Returns which flash attention backend to use."""
# Default case.
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
selected_backend
=
_Backend
.
FLASH_ATTN
...
@@ -100,6 +116,21 @@ def which_attn_to_use(
...
@@ -100,6 +116,21 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
return
_Backend
.
TORCH_SDPA
if
is_openvino
():
if
selected_backend
!=
_Backend
.
OPENVINO
:
logger
.
info
(
"Cannot use %s backend on OpenVINO."
,
selected_backend
)
return
_Backend
.
OPENVINO
if
is_xpu
():
if
selected_backend
!=
_Backend
.
IPEX
:
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
return
_Backend
.
IPEX
if
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
if
is_hip
():
if
is_hip
():
# AMD GPUs.
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
...
...
vllm/block.py
View file @
705f6a35
...
@@ -3,52 +3,9 @@ from typing import List
...
@@ -3,52 +3,9 @@ from typing import List
from
vllm.utils
import
Device
from
vllm.utils
import
Device
_BLANK_TOKEN_ID
=
-
1
DEFAULT_LAST_ACCESSED_TIME
=
-
1
DEFAULT_LAST_ACCESSED_TIME
=
-
1
class LogicalTokenBlock:
    """Fixed-capacity, left-to-right buffer of token ids.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache. Unused slots hold ``_BLANK_TOKEN_ID``
    and ``num_tokens`` counts the slots filled so far.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

        # Every slot starts out blank; tokens overwrite them from the left.
        self.token_ids = [_BLANK_TOKEN_ID for _ in range(block_size)]
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """Return True when no token has been appended yet."""
        return not self.num_tokens

    def get_num_empty_slots(self) -> int:
        """Return how many more tokens fit into this block."""
        remaining = self.block_size - self.num_tokens
        return remaining

    def is_full(self) -> bool:
        """Return True when every slot of the block is occupied."""
        return self.block_size == self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write ``token_ids`` into the next free slots.

        Asserts that the tokens fit into the remaining free slots.
        """
        incoming = len(token_ids)
        assert incoming <= self.get_num_empty_slots()
        first_free = self.num_tokens
        past_last = first_free + incoming
        self.token_ids[first_free:past_last] = token_ids
        self.num_tokens = past_last

    def get_token_ids(self) -> List[int]:
        """Return a list of the tokens stored so far (blank slots excluded)."""
        filled = self.num_tokens
        return self.token_ids[:filled]

    def get_last_token_id(self) -> int:
        """Return the most recently appended token; asserts non-empty."""
        assert self.num_tokens > 0
        last_index = self.num_tokens - 1
        return self.token_ids[last_index]
class
PhysicalTokenBlock
:
class
PhysicalTokenBlock
:
"""Represents the state of a block in the KV cache."""
"""Represents the state of a block in the KV cache."""
...
...
vllm/config.py
View file @
705f6a35
import
enum
import
enum
import
json
import
json
from
dataclasses
import
dataclass
,
field
,
fields
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Tuple
,
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Tuple
,
Union
Union
)
import
torch
import
torch
from
transformers
import
PretrainedConfig
,
PreTrainedTokenizerBase
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.tracing
import
is_otel_installed
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
from
vllm.utils
import
(
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_tpu
,
is_xpu
,
print_warning_once
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -23,6 +25,17 @@ logger = init_logger(__name__)
...
@@ -23,6 +25,17 @@ logger = init_logger(__name__)
_GB
=
1
<<
30
_GB
=
1
<<
30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
_PP_SUPPORTED_MODELS
=
[
"AquilaModel"
,
"AquilaForCausalLM"
,
"InternLMForCausalLM"
,
"LlamaForCausalLM"
,
"LLaMAForCausalLM"
,
"MistralForCausalLM"
,
"Phi3ForCausalLM"
,
"GPT2LMHeadModel"
,
]
class
ModelConfig
:
class
ModelConfig
:
"""Configuration for the model.
"""Configuration for the model.
...
@@ -105,6 +118,7 @@ class ModelConfig:
...
@@ -105,6 +118,7 @@ class ModelConfig:
disable_sliding_window
:
bool
=
False
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
multimodal_config
:
Optional
[
"MultiModalConfig"
]
=
None
,
)
->
None
:
)
->
None
:
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
...
@@ -123,12 +137,10 @@ class ModelConfig:
...
@@ -123,12 +137,10 @@ class ModelConfig:
self
.
quantization
=
quantization
self
.
quantization
=
quantization
self
.
quantization_param_path
=
quantization_param_path
self
.
quantization_param_path
=
quantization_param_path
self
.
enforce_eager
=
enforce_eager
self
.
enforce_eager
=
enforce_eager
self
.
max_context_len_to_capture
=
max_context_len_to_capture
if
max_context_len_to_capture
is
not
None
:
if
self
.
max_context_len_to_capture
is
not
None
:
raise
ValueError
(
"`max_context_len_to_capture` is deprecated. "
raise
ValueError
(
"`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead."
)
"Use `max_seq_len_to_capture` instead."
)
self
.
max_seq_len_to_capture
=
(
max_seq_len_to_capture
self
.
max_seq_len_to_capture
=
max_seq_len_to_capture
or
max_context_len_to_capture
)
self
.
max_logprobs
=
max_logprobs
self
.
max_logprobs
=
max_logprobs
self
.
disable_sliding_window
=
disable_sliding_window
self
.
disable_sliding_window
=
disable_sliding_window
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
...
@@ -137,6 +149,17 @@ class ModelConfig:
...
@@ -137,6 +149,17 @@ class ModelConfig:
code_revision
,
rope_scaling
,
rope_theta
)
code_revision
,
rope_scaling
,
rope_theta
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
if
(
not
self
.
disable_sliding_window
and
self
.
hf_text_config
.
model_type
==
"gemma2"
and
self
.
hf_text_config
.
sliding_window
is
not
None
):
print_warning_once
(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f
"(
{
self
.
hf_text_config
.
sliding_window
}
)."
)
self
.
disable_sliding_window
=
True
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
max_model_len
=
_get_and_verify_max_len
(
hf_config
=
self
.
hf_text_config
,
hf_config
=
self
.
hf_text_config
,
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
...
@@ -144,6 +167,8 @@ class ModelConfig:
...
@@ -144,6 +167,8 @@ class ModelConfig:
sliding_window_len
=
self
.
get_hf_config_sliding_window
())
sliding_window_len
=
self
.
get_hf_config_sliding_window
())
self
.
served_model_name
=
get_served_model_name
(
model
,
self
.
served_model_name
=
get_served_model_name
(
model
,
served_model_name
)
served_model_name
)
self
.
multimodal_config
=
multimodal_config
if
not
self
.
skip_tokenizer_init
:
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
self
.
_verify_embedding_mode
()
self
.
_verify_embedding_mode
()
...
@@ -212,7 +237,7 @@ class ModelConfig:
...
@@ -212,7 +237,7 @@ class ModelConfig:
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
(
self
.
quantization
if
(
self
.
quantization
not
in
[
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
]
):
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
)
):
logger
.
warning
(
logger
.
warning
(
"%s quantization is not fully "
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
...
@@ -228,7 +253,8 @@ class ModelConfig:
...
@@ -228,7 +253,8 @@ class ModelConfig:
self
,
self
,
parallel_config
:
"ParallelConfig"
,
parallel_config
:
"ParallelConfig"
,
)
->
None
:
)
->
None
:
total_num_attention_heads
=
self
.
hf_text_config
.
num_attention_heads
total_num_attention_heads
=
getattr
(
self
.
hf_text_config
,
"num_attention_heads"
,
0
)
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
if
total_num_attention_heads
%
tensor_parallel_size
!=
0
:
if
total_num_attention_heads
%
tensor_parallel_size
!=
0
:
raise
ValueError
(
raise
ValueError
(
...
@@ -236,13 +262,13 @@ class ModelConfig:
...
@@ -236,13 +262,13 @@ class ModelConfig:
" must be divisible by tensor parallel size "
" must be divisible by tensor parallel size "
f
"(
{
tensor_parallel_size
}
)."
)
f
"(
{
tensor_parallel_size
}
)."
)
total_num_hidden_layers
=
self
.
hf_text_config
.
num_hidden_layers
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
if
total_num_hidden_layers
%
pipeline_parallel_size
!=
0
:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
raise
ValueError
(
if
not
all
(
arch
in
_PP_SUPPORTED_MODELS
f
"Total number of hidden layers (
{
total_num_hidden_layers
}
) "
for
arch
in
architectures
)
and
pipeline_parallel_size
>
1
:
"must be divisible by pipeline parallel size "
raise
NotImplementedError
(
f
"(
{
pipeline_parallel_size
}
)."
)
"Pipeline parallelism is only supported for the following "
f
" architectures:
{
_PP_SUPPORTED_MODELS
}
."
)
if
self
.
quantization
==
"bitsandbytes"
and
(
if
self
.
quantization
==
"bitsandbytes"
and
(
parallel_config
.
tensor_parallel_size
>
1
parallel_config
.
tensor_parallel_size
>
1
...
@@ -251,8 +277,7 @@ class ModelConfig:
...
@@ -251,8 +277,7 @@ class ModelConfig:
"BitAndBytes quantization with TP or PP is not supported yet."
)
"BitAndBytes quantization with TP or PP is not supported yet."
)
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled.
"""Get the sliding window size, or None if disabled."""
"""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# addition to sliding window size. We check if that field is present
...
@@ -307,7 +332,11 @@ class ModelConfig:
...
@@ -307,7 +332,11 @@ class ModelConfig:
return
1
return
1
# For DBRX and MPT
# For DBRX and MPT
if
self
.
hf_config
.
model_type
in
[
"dbrx"
,
"mpt"
]:
if
self
.
hf_config
.
model_type
==
"mpt"
:
if
"kv_n_heads"
in
self
.
hf_config
.
attn_config
:
return
self
.
hf_config
.
attn_config
[
"kv_n_heads"
]
return
self
.
hf_config
.
num_attention_heads
if
self
.
hf_config
.
model_type
==
"dbrx"
:
return
getattr
(
self
.
hf_config
.
attn_config
,
"kv_n_heads"
,
return
getattr
(
self
.
hf_config
.
attn_config
,
"kv_n_heads"
,
self
.
hf_config
.
num_attention_heads
)
self
.
hf_config
.
num_attention_heads
)
...
@@ -341,12 +370,43 @@ class ModelConfig:
...
@@ -341,12 +370,43 @@ class ModelConfig:
def
get_num_attention_heads
(
self
,
def
get_num_attention_heads
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
parallel_config
:
"ParallelConfig"
)
->
int
:
return
self
.
hf_text_config
.
num_attention_heads
//
\
num_heads
=
getattr
(
self
.
hf_text_config
,
"
num_attention_heads
"
,
0
)
parallel_config
.
tensor_parallel_size
return
num_heads
//
parallel_config
.
tensor_parallel_size
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
total_num_hidden_layers
=
self
.
hf_text_config
.
num_hidden_layers
from
vllm.distributed.utils
import
get_pp_indices
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_hidden_layers"
,
0
)
pp_rank
=
parallel_config
.
rank
//
parallel_config
.
tensor_parallel_size
pp_size
=
parallel_config
.
pipeline_parallel_size
start
,
end
=
get_pp_indices
(
total_num_hidden_layers
,
pp_rank
,
pp_size
)
return
end
-
start
def
contains_seqlen_agnostic_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
bool
:
"""True for Mamba/SSM models (Jamba)"""
return
self
.
_get_num_seqlen_agnostic_layers
(
parallel_config
)
>
0
def
get_layers_block_type
(
self
,
parallel_config
:
"ParallelConfig"
)
->
List
[
str
]:
num_layers
=
self
.
get_num_layers
(
parallel_config
)
# Transformers supports layers_block_type @property
return
getattr
(
self
.
hf_config
,
"layers_block_type"
,
[
"attention"
]
*
num_layers
)
def
get_num_attention_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
return
len
([
t
for
t
in
self
.
get_layers_block_type
(
parallel_config
)
if
t
==
"attention"
])
def
_get_num_seqlen_agnostic_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
return
len
([
t
for
t
in
self
.
get_layers_block_type
(
parallel_config
)
if
t
!=
"attention"
])
class
CacheConfig
:
class
CacheConfig
:
...
@@ -611,45 +671,50 @@ class ParallelConfig:
...
@@ -611,45 +671,50 @@ class ParallelConfig:
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
# We use multiprocessing by default if world_size fits on the
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
# current node and we aren't in a ray placement group.
from
torch.cuda
import
device_count
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
backend
=
"mp"
backend
=
"mp"
ray_found
=
ray_utils
.
ray
is
not
None
ray_found
=
ray_utils
.
ray
_
is
_available
()
if
device_count
()
<
self
.
world_size
:
if
cuda_
device_count
_stateless
()
<
self
.
world_size
:
if
not
ray_found
:
if
not
ray_found
:
raise
ValueError
(
"Unable to load Ray which is "
raise
ValueError
(
"Unable to load Ray which is "
"required for multi-node inference"
)
"required for multi-node inference, "
"please install Ray with `pip install "
"ray`."
)
from
ray_utils
.
ray_import_err
backend
=
"ray"
backend
=
"ray"
elif
ray_found
:
elif
ray_found
:
from
ray.util
import
get_current_placement_group
if
self
.
placement_group
:
if
self
.
placement_group
or
get_current_placement_group
():
backend
=
"ray"
backend
=
"ray"
else
:
from
ray
import
is_initialized
as
ray_is_initialized
if
ray_is_initialized
():
from
ray.util
import
get_current_placement_group
if
get_current_placement_group
():
backend
=
"ray"
self
.
distributed_executor_backend
=
backend
self
.
distributed_executor_backend
=
backend
logger
.
info
(
"Defaulting to use %s for distributed inference"
,
logger
.
info
(
"Defaulting to use %s for distributed inference"
,
backend
)
backend
)
self
.
_verify_args
()
self
.
_verify_args
()
self
.
rank
=
0
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
if
self
.
pipeline_parallel_size
>
1
:
if
(
self
.
pipeline_parallel_size
>
1
raise
NotImplementedError
(
and
self
.
distributed_executor_backend
==
"mp"
):
"Pipeline parallelism is not supported yet."
)
raise
NotImplementedError
(
"Pipeline parallelism is not supported "
"yet with multiprocessing."
)
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
None
):
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
None
):
raise
ValueError
(
raise
ValueError
(
"Unrecognized distributed executor backend. Supported values "
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'."
)
"are 'ray' or 'mp'."
)
if
not
self
.
disable_custom_all_reduce
and
self
.
world_size
>
1
:
if
self
.
distributed_executor_backend
==
"ray"
:
if
is_hip
():
from
vllm.executor
import
ray_utils
self
.
disable_custom_all_reduce
=
True
ray_utils
.
assert_ray_available
()
logger
.
info
(
if
is_hip
():
"Disabled the custom all-reduce kernel because it is not "
self
.
disable_custom_all_reduce
=
True
"supported on AMD GPUs."
)
logger
.
info
(
elif
self
.
pipeline_parallel_size
>
1
:
"Disabled the custom all-reduce kernel because it is not "
self
.
disable_custom_all_reduce
=
True
"supported on AMD GPUs."
)
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism."
)
if
self
.
ray_workers_use_nsight
and
(
if
self
.
ray_workers_use_nsight
and
(
not
self
.
distributed_executor_backend
==
"ray"
):
not
self
.
distributed_executor_backend
==
"ray"
):
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
...
@@ -720,7 +785,6 @@ class SchedulerConfig:
...
@@ -720,7 +785,6 @@ class SchedulerConfig:
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
preemption_mode
=
preemption_mode
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
@@ -754,8 +818,14 @@ class DeviceConfig:
...
@@ -754,8 +818,14 @@ class DeviceConfig:
# Automated device type detection
# Automated device type detection
if
is_neuron
():
if
is_neuron
():
self
.
device_type
=
"neuron"
self
.
device_type
=
"neuron"
elif
is_openvino
():
self
.
device_type
=
"openvino"
elif
is_tpu
():
self
.
device_type
=
"tpu"
elif
is_cpu
():
elif
is_cpu
():
self
.
device_type
=
"cpu"
self
.
device_type
=
"cpu"
elif
is_xpu
():
self
.
device_type
=
"xpu"
else
:
else
:
# We don't call torch.cuda.is_available() here to
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
# avoid initializing CUDA before workers are forked
...
@@ -765,8 +835,10 @@ class DeviceConfig:
...
@@ -765,8 +835,10 @@ class DeviceConfig:
self
.
device_type
=
device
self
.
device_type
=
device
# Some device types require processing inputs on CPU
# Some device types require processing inputs on CPU
if
self
.
device_type
in
[
"neuron"
]:
if
self
.
device_type
in
[
"neuron"
,
"openvino"
]:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
elif
self
.
device_type
in
[
"tpu"
]:
self
.
device
=
None
else
:
else
:
# Set device with device type
# Set device with device type
self
.
device
=
torch
.
device
(
self
.
device_type
)
self
.
device
=
torch
.
device
(
self
.
device_type
)
...
@@ -785,6 +857,7 @@ class SpeculativeConfig:
...
@@ -785,6 +857,7 @@ class SpeculativeConfig:
target_parallel_config
:
ParallelConfig
,
target_parallel_config
:
ParallelConfig
,
target_dtype
:
str
,
target_dtype
:
str
,
speculative_model
:
Optional
[
str
],
speculative_model
:
Optional
[
str
],
speculative_draft_tensor_parallel_size
:
Optional
[
int
],
num_speculative_tokens
:
Optional
[
int
],
num_speculative_tokens
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
...
@@ -792,6 +865,9 @@ class SpeculativeConfig:
...
@@ -792,6 +865,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
:
Optional
[
int
],
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
Optional
[
float
],
typical_acceptance_sampler_posterior_alpha
:
Optional
[
float
],
)
->
Optional
[
"SpeculativeConfig"
]:
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
"""Create a SpeculativeConfig if possible, else return None.
...
@@ -807,8 +883,11 @@ class SpeculativeConfig:
...
@@ -807,8 +883,11 @@ class SpeculativeConfig:
target_dtype (str): The data type used for the target model.
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
speculative_model (Optional[str]): The name of the speculative
model, if provided.
model, if provided.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_max_model_len (Optional[int]): The maximum model len of
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
the speculative model. Used when testing the ability to skip
speculation for some sequences.
speculation for some sequences.
...
@@ -825,30 +904,37 @@ class SpeculativeConfig:
...
@@ -825,30 +904,37 @@ class SpeculativeConfig:
window, if provided.
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
Returns:
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
the necessary conditions are met, else None.
"""
"""
if
speculative_model
is
None
and
num_speculative_tokens
is
None
:
if
speculative_model
is
None
:
if
num_speculative_tokens
is
not
None
:
raise
ValueError
(
"num_speculative_tokens was provided without "
"speculative_model."
)
return
None
return
None
if
speculative_model
is
not
None
and
num_speculative_tokens
is
None
:
raise
ValueError
(
"Expected both speculative_model and "
"num_speculative_tokens to be provided, but found "
f
"
{
speculative_model
=
}
and
{
num_speculative_tokens
=
}
."
)
if
(
speculative_disable_by_batch_size
is
not
None
if
(
speculative_disable_by_batch_size
is
not
None
and
speculative_disable_by_batch_size
<
2
):
and
speculative_disable_by_batch_size
<
2
):
raise
ValueError
(
"Expect the batch size threshold of disabling "
raise
ValueError
(
"Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
"speculative decoding is > 1, but got "
f
"
{
speculative_disable_by_batch_size
=
}
"
)
f
"
{
speculative_disable_by_batch_size
=
}
"
)
assert
(
speculative_model
is
not
None
and
num_speculative_tokens
is
not
None
)
if
enable_chunked_prefill
:
if
enable_chunked_prefill
:
raise
ValueError
(
raise
ValueError
(
"Speculative decoding and chunked prefill are "
"Speculative decoding and chunked prefill are "
...
@@ -902,6 +988,25 @@ class SpeculativeConfig:
...
@@ -902,6 +988,25 @@ class SpeculativeConfig:
max_logprobs
=
target_model_config
.
max_logprobs
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
)
draft_hf_config
=
draft_model_config
.
hf_config
if
(
num_speculative_tokens
is
not
None
and
hasattr
(
draft_hf_config
,
"num_lookahead_tokens"
)):
draft_hf_config
.
num_lookahead_tokens
=
num_speculative_tokens
n_predict
=
getattr
(
draft_hf_config
,
"n_predict"
,
None
)
if
n_predict
is
not
None
:
if
num_speculative_tokens
is
None
:
# Default to max value defined in draft model config.
num_speculative_tokens
=
n_predict
elif
num_speculative_tokens
>
n_predict
:
# Verify provided value doesn't exceed the maximum
# supported by the draft model.
raise
ValueError
(
"This speculative model supports a maximum of "
f
"num_speculative_tokens=
{
n_predict
}
, but "
f
"
{
num_speculative_tokens
=
}
was provided."
)
draft_model_config
.
max_model_len
=
(
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
speculative_max_model_len
,
...
@@ -911,7 +1016,19 @@ class SpeculativeConfig:
...
@@ -911,7 +1016,19 @@ class SpeculativeConfig:
draft_parallel_config
=
(
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
))
target_parallel_config
,
speculative_draft_tensor_parallel_size
))
if
num_speculative_tokens
is
None
:
raise
ValueError
(
"num_speculative_tokens must be provided with "
"speculative_model unless the draft model config contains an "
"n_predict parameter."
)
if
typical_acceptance_sampler_posterior_threshold
is
None
:
typical_acceptance_sampler_posterior_threshold
=
0.09
if
typical_acceptance_sampler_posterior_alpha
is
None
:
typical_acceptance_sampler_posterior_alpha
=
0.3
return
SpeculativeConfig
(
return
SpeculativeConfig
(
draft_model_config
,
draft_model_config
,
...
@@ -920,6 +1037,11 @@ class SpeculativeConfig:
...
@@ -920,6 +1037,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
,
speculative_disable_by_batch_size
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
ngram_prompt_lookup_min
,
draft_token_acceptance_method
=
draft_token_acceptance_method
,
typical_acceptance_sampler_posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
,
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
,
)
)
@
staticmethod
@
staticmethod
...
@@ -959,16 +1081,26 @@ class SpeculativeConfig:
...
@@ -959,16 +1081,26 @@ class SpeculativeConfig:
@
staticmethod
@
staticmethod
def
create_draft_parallel_config
(
def
create_draft_parallel_config
(
target_parallel_config
:
ParallelConfig
)
->
ParallelConfig
:
target_parallel_config
:
ParallelConfig
,
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
)
->
ParallelConfig
:
"""Create a parallel config for use by the draft worker.
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config. In the future the
This is mostly a copy of the target parallel config, except the tp_size.
draft worker can have a different parallel strategy, e.g. TP=1.
"""
"""
if
speculative_draft_tensor_parallel_size
is
None
:
speculative_draft_tensor_parallel_size
=
\
target_parallel_config
.
tensor_parallel_size
elif
speculative_draft_tensor_parallel_size
!=
1
:
# TODO(wooyeon): allow tp values larger than 1
raise
ValueError
(
f
"
{
speculative_draft_tensor_parallel_size
=
}
cannot be"
f
"other value than 1"
)
draft_parallel_config
=
ParallelConfig
(
draft_parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
target_parallel_config
.
pipeline_parallel_size
=
target_parallel_config
.
pipeline_parallel_size
,
pipeline_parallel_size
,
tensor_parallel_size
=
target_parallel_config
.
tensor_parallel_size
,
tensor_parallel_size
=
speculative_draft_
tensor_parallel_size
,
distributed_executor_backend
=
target_parallel_config
.
distributed_executor_backend
=
target_parallel_config
.
distributed_executor_backend
,
distributed_executor_backend
,
max_parallel_loading_workers
=
target_parallel_config
.
max_parallel_loading_workers
=
target_parallel_config
.
...
@@ -991,6 +1123,9 @@ class SpeculativeConfig:
...
@@ -991,6 +1123,9 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
:
Optional
[
int
],
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
draft_token_acceptance_method
:
str
,
typical_acceptance_sampler_posterior_threshold
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -1004,6 +1139,19 @@ class SpeculativeConfig:
...
@@ -1004,6 +1139,19 @@ class SpeculativeConfig:
enqueue requests is larger than this value.
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
"""
"""
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
...
@@ -1012,6 +1160,11 @@ class SpeculativeConfig:
...
@@ -1012,6 +1160,11 @@ class SpeculativeConfig:
speculative_disable_by_batch_size
speculative_disable_by_batch_size
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
or
0
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
or
0
self
.
draft_token_acceptance_method
=
draft_token_acceptance_method
self
.
typical_acceptance_sampler_posterior_threshold
=
\
typical_acceptance_sampler_posterior_threshold
self
.
typical_acceptance_sampler_posterior_alpha
=
\
typical_acceptance_sampler_posterior_alpha
self
.
_verify_args
()
self
.
_verify_args
()
...
@@ -1023,6 +1176,31 @@ class SpeculativeConfig:
...
@@ -1023,6 +1176,31 @@ class SpeculativeConfig:
if
self
.
draft_model_config
:
if
self
.
draft_model_config
:
self
.
draft_model_config
.
verify_with_parallel_config
(
self
.
draft_model_config
.
verify_with_parallel_config
(
self
.
draft_parallel_config
)
self
.
draft_parallel_config
)
# Validate and set draft token acceptance related settings.
if
(
self
.
draft_token_acceptance_method
is
None
):
raise
ValueError
(
"draft_token_acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler."
)
if
(
self
.
draft_token_acceptance_method
!=
'rejection_sampler'
and
self
.
draft_token_acceptance_method
!=
'typical_acceptance_sampler'
):
raise
ValueError
(
"Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f
"is
{
self
.
draft_token_acceptance_method
}
"
)
if
(
self
.
typical_acceptance_sampler_posterior_threshold
<
0
or
self
.
typical_acceptance_sampler_posterior_alpha
<
0
):
raise
ValueError
(
"Expected typical_acceptance_sampler_posterior_threshold "
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
"Instead found "
f
"typical_acceptance_sampler_posterior_threshold = "
f
"
{
self
.
typical_acceptance_sampler_posterior_threshold
}
and "
f
"typical_acceptance_sampler_posterior_alpha = "
f
"
{
self
.
typical_acceptance_sampler_posterior_alpha
}
"
)
@
property
@
property
def
num_lookahead_slots
(
self
)
->
int
:
def
num_lookahead_slots
(
self
)
->
int
:
...
@@ -1094,79 +1272,49 @@ class LoRAConfig:
...
@@ -1094,79 +1272,49 @@ class LoRAConfig:
"Due to limitations of the custom LoRA CUDA kernel, "
"Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
"max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled."
)
"LoRA is enabled."
)
if
scheduler_config
.
chunked_prefill_enabled
:
raise
ValueError
(
"LoRA is not supported with chunked prefill yet."
)
@
dataclass
@
dataclass
class
VisionLanguageConfig
:
class
PromptAdapterConfig
:
"""Configs the input data format and how models should run for
max_prompt_adapters
:
int
vision language models."""
max_prompt_adapter_token
:
int
max_cpu_prompt_adapters
:
Optional
[
int
]
=
None
class
ImageInputType
(
enum
.
Enum
):
prompt_adapter_dtype
:
Optional
[
torch
.
dtype
]
=
None
"""Image input type into the vision language model.
An image roughly goes through the following transformation:
def
__post_init__
(
self
):
Raw image --> pixel values --> image features --> image embeddings.
library_name
=
'peft'
The difference between different image input types is where the
image encoder (pixel values --> image features) is run.
Different image input types also correspond to different tensor shapes.
For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
IMAGE_FEATURES: (1, 576, 1024).
"""
PIXEL_VALUES
=
enum
.
auto
()
IMAGE_FEATURES
=
enum
.
auto
()
image_input_type
:
ImageInputType
# The input id corresponding to image token.
image_token_id
:
int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape
:
tuple
image_feature_size
:
int
# The image processor to load from HuggingFace
image_processor
:
Optional
[
str
]
image_processor_revision
:
Optional
[
str
]
@
classmethod
def
get_image_input_enum_type
(
cls
,
value
:
str
)
->
ImageInputType
:
"""Get the image input type from a string."""
try
:
try
:
return
cls
.
ImageInputType
[
value
.
upper
()]
__import__
(
library_name
)
except
KeyError
as
e
:
except
ImportError
as
e
:
raise
ValueError
(
f
"
{
value
}
is not a valid choice. "
raise
ImportError
(
f
"Expecting to choose from "
f
"'
{
library_name
}
' is not installed for prompt adapter support."
f
"
{
[
x
.
name
for
x
in
cls
.
ImageInputType
]
}
."
)
from
e
f
"Please install it using 'pip install
{
library_name
}
'."
)
from
e
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class.
if
self
.
max_prompt_adapters
<
1
:
def
get_image_token_text
(
raise
ValueError
(
f
"max_prompt_adapters "
self
,
tokenizer
:
PreTrainedTokenizerBase
)
->
Tuple
[
str
,
str
]:
f
"(
{
self
.
max_prompt_adapters
}
) must be >= 1."
)
"""Get the image token placeholder text to be inserted into the
if
self
.
max_prompt_adapter_token
==
0
:
text prompt and the string representation of the image token id.
raise
ValueError
(
"max_prompt_adapter_token must be set."
)
"""
if
self
.
max_cpu_prompt_adapters
is
None
:
image_token_str
=
tokenizer
.
decode
(
self
.
image_token_id
)
self
.
max_cpu_prompt_adapters
=
self
.
max_prompt_adapters
return
image_token_str
*
self
.
image_feature_size
,
image_token_str
def
as_cli_args_dict
(
self
)
->
Dict
[
str
,
Any
]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
def
verify_with_model_config
(
self
,
model_config
:
ModelConfig
):
"""
if
self
.
prompt_adapter_dtype
in
(
None
,
"auto"
):
result
:
Dict
[
str
,
Any
]
=
{}
self
.
prompt_adapter_dtype
=
model_config
.
dtype
for
f
in
fields
(
self
):
elif
isinstance
(
self
.
prompt_adapter_dtype
,
str
):
value
=
getattr
(
self
,
f
.
name
)
self
.
prompt_adapter_dtype
=
getattr
(
torch
,
if
isinstance
(
value
,
enum
.
Enum
):
self
.
prompt_adapter_dtype
)
result
[
f
.
name
]
=
value
.
name
.
lower
()
elif
isinstance
(
value
,
tuple
):
result
[
f
.
name
]
=
","
.
join
([
str
(
item
)
for
item
in
value
])
else
:
result
[
f
.
name
]
=
value
result
[
"disable_image_processor"
]
=
self
.
image_processor
is
None
return
result
@
dataclass
class
MultiModalConfig
:
"""Configs the input data format and how models should run for
multimodal models."""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE
=
{
_STR_DTYPE_TO_TORCH_DTYPE
=
{
...
@@ -1194,10 +1342,16 @@ def _get_and_verify_dtype(
...
@@ -1194,10 +1342,16 @@ def _get_and_verify_dtype(
dtype
=
dtype
.
lower
()
dtype
=
dtype
.
lower
()
if
dtype
==
"auto"
:
if
dtype
==
"auto"
:
if
config_dtype
==
torch
.
float32
:
if
config_dtype
==
torch
.
float32
:
# Following the common practice, we use float16 for float32
if
config
.
model_type
==
"gemma2"
:
# models.
logger
.
info
(
logger
.
info
(
"Casting torch.float32 to torch.float16."
)
"For Gemma 2, we downcast float32 to bfloat16 instead "
torch_dtype
=
torch
.
float16
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
torch_dtype
=
torch
.
bfloat16
else
:
# Following the common practice, we use float16 for float32
# models.
torch_dtype
=
torch
.
float16
else
:
else
:
torch_dtype
=
config_dtype
torch_dtype
=
config_dtype
else
:
else
:
...
@@ -1282,7 +1436,10 @@ def _get_and_verify_max_len(
...
@@ -1282,7 +1436,10 @@ def _get_and_verify_max_len(
derived_max_model_len
=
default_max_len
derived_max_model_len
=
default_max_len
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
\
and
rope_scaling
[
"type"
]
!=
"longrope"
:
if
disable_sliding_window
:
if
disable_sliding_window
:
# TODO(robertgshaw): Find a model that supports rope_scaling
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
# with sliding window to see if this case should be allowed.
...
@@ -1357,6 +1514,17 @@ class DecodingConfig:
...
@@ -1357,6 +1514,17 @@ class DecodingConfig:
f
"must be one of
{
valid_guided_backends
}
"
)
f
"must be one of
{
valid_guided_backends
}
"
)
@
dataclass
class
ObservabilityConfig
:
"""Configuration for observability."""
otlp_traces_endpoint
:
Optional
[
str
]
=
None
def
__post_init__
(
self
):
if
not
is_otel_installed
()
and
self
.
otlp_traces_endpoint
is
not
None
:
raise
ValueError
(
"OpenTelemetry packages must be installed before "
"configuring 'otlp_traces_endpoint'"
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
EngineConfig
:
class
EngineConfig
:
"""Dataclass which contains all engine-related configuration. This
"""Dataclass which contains all engine-related configuration. This
...
@@ -1370,9 +1538,11 @@ class EngineConfig:
...
@@ -1370,9 +1538,11 @@ class EngineConfig:
device_config
:
DeviceConfig
device_config
:
DeviceConfig
load_config
:
LoadConfig
load_config
:
LoadConfig
lora_config
:
Optional
[
LoRAConfig
]
lora_config
:
Optional
[
LoRAConfig
]
vision_language
_config
:
Optional
[
VisionLanguage
Config
]
multimodal
_config
:
Optional
[
MultiModal
Config
]
speculative_config
:
Optional
[
SpeculativeConfig
]
speculative_config
:
Optional
[
SpeculativeConfig
]
decoding_config
:
Optional
[
DecodingConfig
]
decoding_config
:
Optional
[
DecodingConfig
]
observability_config
:
Optional
[
ObservabilityConfig
]
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""Verify configs are valid & consistent with each other.
...
@@ -1384,6 +1554,9 @@ class EngineConfig:
...
@@ -1384,6 +1554,9 @@ class EngineConfig:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
lora_config
.
verify_with_scheduler_config
(
self
.
scheduler_config
)
self
.
scheduler_config
)
if
self
.
prompt_adapter_config
:
self
.
prompt_adapter_config
.
verify_with_model_config
(
self
.
model_config
)
def
to_dict
(
self
):
def
to_dict
(
self
):
"""Return the configs as a dictionary, for use in **kwargs.
"""Return the configs as a dictionary, for use in **kwargs.
...
...
vllm/core/block/block_table.py
View file @
705f6a35
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
vllm.core.block.common
import
BlockList
from
vllm.core.block.interfaces
import
Block
,
DeviceAwareBlockAllocator
from
vllm.core.block.interfaces
import
Block
,
DeviceAwareBlockAllocator
from
vllm.utils
import
Device
,
cdiv
,
chunk_list
from
vllm.utils
import
Device
,
cdiv
,
chunk_list
...
@@ -47,12 +48,10 @@ class BlockTable:
...
@@ -47,12 +48,10 @@ class BlockTable:
self
.
_allocator
=
block_allocator
self
.
_allocator
=
block_allocator
if
_blocks
is
None
:
if
_blocks
is
None
:
_blocks
=
[]
_blocks
=
[]
self
.
_blocks
:
List
[
Block
]
=
_blocks
self
.
_blocks
:
Block
List
=
Block
List
(
_blocks
)
self
.
_max_block_sliding_window
=
max_block_sliding_window
self
.
_max_block_sliding_window
=
max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
self
.
_num_full_slots
=
self
.
_get_num_token_ids
()
# may not be allocated.
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
@
staticmethod
@
staticmethod
def
get_num_required_blocks
(
token_ids
:
List
[
int
],
block_size
:
int
)
->
int
:
def
get_num_required_blocks
(
token_ids
:
List
[
int
],
block_size
:
int
)
->
int
:
...
@@ -88,11 +87,18 @@ class BlockTable:
...
@@ -88,11 +87,18 @@ class BlockTable:
"""
"""
assert
not
self
.
_is_allocated
assert
not
self
.
_is_allocated
assert
token_ids
assert
token_ids
self
.
_blocks
=
self
.
_allocate_blocks_for_token_ids
(
prev_block
=
None
,
blocks
=
self
.
_allocate_blocks_for_token_ids
(
prev_block
=
None
,
token_ids
=
token_ids
,
token_ids
=
token_ids
,
device
=
device
)
device
=
device
)
self
.
update
(
blocks
)
self
.
_num_full_slots
=
len
(
token_ids
)
self
.
_num_full_slots
=
len
(
token_ids
)
def update(self, blocks: List[Block]) -> None:
    """Replace the table's contents with ``blocks``.

    The work is delegated to the ``update`` method of the underlying
    ``self._blocks`` container, which receives the new blocks (together
    with their block ids).
    """
    block_list = self._blocks
    block_list.update(blocks)
def
append_token_ids
(
self
,
def
append_token_ids
(
self
,
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
=
0
,
num_lookahead_slots
:
int
=
0
,
...
@@ -140,11 +146,11 @@ class BlockTable:
...
@@ -140,11 +146,11 @@ class BlockTable:
num_lookahead_slots
)
num_lookahead_slots
)
# Update the blocks with the new tokens
# Update the blocks with the new tokens
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
first_block_idx
=
self
.
_num_full_slots
//
self
.
_block_size
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
for
block
,
token_block
in
zip
(
blocks
,
token_blocks
):
for
i
,
token_block
in
enumerate
(
token_blocks
):
block
.
append_token_ids
(
token_block
)
self
.
_
block
s
.
append_token_ids
(
first_block_idx
+
i
,
token_block
)
self
.
_num_full_slots
+=
len
(
token_ids
)
self
.
_num_full_slots
+=
len
(
token_ids
)
...
@@ -174,8 +180,8 @@ class BlockTable:
...
@@ -174,8 +180,8 @@ class BlockTable:
for
_
in
range
(
blocks_to_allocate
):
for
_
in
range
(
blocks_to_allocate
):
assert
len
(
self
.
_blocks
)
>
0
assert
len
(
self
.
_blocks
)
>
0
self
.
_blocks
.
append
(
self
.
_blocks
.
append
(
self
.
_allocator
.
allocate_mutable
(
prev
_block
=
self
.
_blocks
[
-
1
],
self
.
_allocator
.
allocate_mutable_block
(
device
=
device
))
prev_block
=
self
.
_blocks
[
-
1
],
device
=
device
))
def
fork
(
self
)
->
"BlockTable"
:
def
fork
(
self
)
->
"BlockTable"
:
"""Creates a new BlockTable instance with a copy of the blocks from the
"""Creates a new BlockTable instance with a copy of the blocks from the
...
@@ -209,12 +215,12 @@ class BlockTable:
...
@@ -209,12 +215,12 @@ class BlockTable:
is set to `None`.
is set to `None`.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
for
block
in
self
.
_
blocks
:
for
block
in
self
.
blocks
:
self
.
_allocator
.
free
(
block
)
self
.
_allocator
.
free
(
block
)
self
.
_blocks
=
[]
self
.
_blocks
.
reset
()
@
property
@
property
def
physical_block_ids
(
self
)
->
List
[
Optional
[
int
]
]
:
def
physical_block_ids
(
self
)
->
List
[
int
]:
"""Returns a list of physical block indices for the blocks in the
"""Returns a list of physical block indices for the blocks in the
BlockTable.
BlockTable.
...
@@ -228,7 +234,7 @@ class BlockTable:
...
@@ -228,7 +234,7 @@ class BlockTable:
BlockTable.
BlockTable.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
return
[
block
.
block_id
for
block
in
self
.
_blocks
]
return
self
.
_blocks
.
ids
()
def
get_unseen_token_ids
(
self
,
sequence_token_ids
:
List
[
int
])
->
List
[
int
]:
def
get_unseen_token_ids
(
self
,
sequence_token_ids
:
List
[
int
])
->
List
[
int
]:
"""Get the number of "unseen" tokens in the sequence.
"""Get the number of "unseen" tokens in the sequence.
...
@@ -252,18 +258,32 @@ class BlockTable:
...
@@ -252,18 +258,32 @@ class BlockTable:
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
device
:
Device
)
->
List
[
Block
]:
device
:
Device
)
->
List
[
Block
]:
blocks
=
[]
blocks
:
List
[
Block
]
=
[]
for
block_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
if
len
(
block_token_ids
)
==
self
.
_block_size
:
block_token_ids
=
[]
# If the block is full, create an immutable block.
tail_token_ids
=
[]
prev_block
=
self
.
_allocator
.
allocate_immutable
(
for
cur_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
prev_block
,
token_ids
=
block_token_ids
,
device
=
device
)
if
len
(
cur_token_ids
)
==
self
.
_block_size
:
block_token_ids
.
append
(
cur_token_ids
)
else
:
else
:
# Else, partially fill a mutable block with token ids.
tail_token_ids
.
append
(
cur_token_ids
)
prev_block
=
self
.
_allocator
.
allocate_mutable
(
prev_block
=
prev_block
,
device
=
device
)
if
block_token_ids
:
prev_block
.
append_token_ids
(
block_token_ids
)
blocks
.
extend
(
blocks
.
append
(
prev_block
)
self
.
_allocator
.
allocate_immutable_blocks
(
prev_block
,
block_token_ids
=
block_token_ids
,
device
=
device
))
prev_block
=
blocks
[
-
1
]
if
tail_token_ids
:
assert
len
(
tail_token_ids
)
==
1
cur_token_ids
=
tail_token_ids
[
0
]
block
=
self
.
_allocator
.
allocate_mutable_block
(
prev_block
=
prev_block
,
device
=
device
)
block
.
append_token_ids
(
cur_token_ids
)
blocks
.
append
(
block
)
return
blocks
return
blocks
...
@@ -274,18 +294,25 @@ class BlockTable:
...
@@ -274,18 +294,25 @@ class BlockTable:
if
not
self
.
_is_allocated
:
if
not
self
.
_is_allocated
:
return
token_ids
return
token_ids
for
block
in
self
.
_
blocks
:
for
block
in
self
.
blocks
:
token_ids
.
extend
(
block
.
token_ids
)
token_ids
.
extend
(
block
.
token_ids
)
return
token_ids
return
token_ids
def _get_num_token_ids(self) -> int:
    """Return the total number of token ids held across ``self.blocks``."""
    return sum(len(block.token_ids) for block in self.blocks)
@
property
@
property
def
_is_allocated
(
self
)
->
bool
:
def
_is_allocated
(
self
)
->
bool
:
return
len
(
self
.
_blocks
)
>
0
return
len
(
self
.
_blocks
)
>
0
@
property
@
property
def
blocks
(
self
)
->
Optional
[
List
[
Block
]
]
:
def
blocks
(
self
)
->
List
[
Block
]:
return
self
.
_blocks
return
self
.
_blocks
.
list
()
@
property
@
property
def
_num_empty_slots
(
self
)
->
int
:
def
_num_empty_slots
(
self
)
->
int
:
...
...
vllm/core/block/common.py
View file @
705f6a35
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
collections
import
deque
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
...
@@ -95,64 +96,40 @@ class CopyOnWriteTracker:
...
@@ -95,64 +96,40 @@ class CopyOnWriteTracker:
The CopyOnWriteTracker class maintains a mapping of source block indices to
The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in
their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference
conjunction with a RefCounter.
counting and block allocation.
Args:
Args:
refcounter (RefCounter): The reference counter used to track block
refcounter (RefCounter): The reference counter used to track block
reference counts.
reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
"""
"""
def
__init__
(
def
__init__
(
self
,
refcounter
:
RefCounterProtocol
):
self
,
refcounter
:
RefCounterProtocol
,
allocator
:
BlockAllocator
,
):
self
.
_copy_on_writes
:
List
[
Tuple
[
BlockId
,
BlockId
]]
=
[]
self
.
_copy_on_writes
:
List
[
Tuple
[
BlockId
,
BlockId
]]
=
[]
self
.
_refcounter
=
refcounter
self
.
_refcounter
=
refcounter
self
.
_allocator
=
allocator
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns:
def
is_appendable
(
self
,
block
:
Block
)
->
bool
:
Optional[BlockId]: The block index of the new block if a copy-on
"""Checks if the block is shared or not. If shared, then it cannot
-write operation was performed, or the original block index if
be appended and needs to be duplicated via copy-on-write
no copy-on-write was necessary.
"""
"""
block_id
=
block
.
block_id
block_id
=
block
.
block_id
if
block_id
is
None
:
if
block_id
is
None
:
return
block_id
return
True
refcount
=
self
.
_refcounter
.
get
(
block_id
)
refcount
=
self
.
_refcounter
.
get
(
block_id
)
assert
refcount
!=
0
return
refcount
<=
1
if
refcount
>
1
:
src_block_id
=
block_id
# Decrement refcount of the old block.
self
.
_allocator
.
free
(
block
)
# Allocate a fresh new block.
block_id
=
self
.
_allocator
.
allocate_mutable
(
prev_block
=
block
.
prev_block
).
block_id
# Track src/dst copy.
def
record_cow
(
self
,
src_block_id
:
Optional
[
BlockId
],
assert
src_block_id
is
not
None
trg_block_id
:
Optional
[
BlockId
])
->
None
:
assert
block_id
is
not
None
"""Records a copy-on-write operation from source to target block id
self
.
_copy_on_writes
.
append
((
src_block_id
,
block_id
))
Args:
src_block_id (BlockId): The source block id from which to copy
return
block_id
the data
trg_block_id (BlockId): The target block id to which the data
is copied
"""
assert
src_block_id
is
not
None
assert
trg_block_id
is
not
None
self
.
_copy_on_writes
.
append
((
src_block_id
,
trg_block_id
))
def
clear_cows
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
def
clear_cows
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Clears the copy-on-write tracking information and returns the current
"""Clears the copy-on-write tracking information and returns the current
...
@@ -172,6 +149,139 @@ class CopyOnWriteTracker:
...
@@ -172,6 +149,139 @@ class CopyOnWriteTracker:
return
cows
return
cows
class BlockPool:
    """Used to pre-allocate block objects, in order to avoid excessive python
    object allocations/deallocations.
    The pool starts from "pool_size" objects and will increase to more objects
    if necessary

    Note that multiple block objects may point to the same physical block id,
    which is why this pool is needed, so that it will be easier to support
    prefix caching and more complicated sharing of physical blocks.
    """

    def __init__(self, block_size: int, create_block: Block.Factory,
                 allocator: BlockAllocator, pool_size: int):
        """Build the pool and eagerly create ``pool_size`` block objects.

        Args:
            block_size: Block size handed to every pre-created block.
            create_block: Factory called with ``prev_block``, ``token_ids``,
                ``block_size``, ``allocator`` and ``block_id`` keywords.
            allocator: Allocator handed to every pre-created block.
            pool_size: Initial number of pooled objects; must be >= 0.
        """
        self._block_size = block_size
        self._create_block = create_block
        self._allocator = allocator
        self._pool_size = pool_size
        assert self._pool_size >= 0

        # Pool slots that are currently not handed out.
        self._free_ids: Deque[int] = deque(range(self._pool_size))
        self._pool = [self._new_block() for _ in range(self._pool_size)]

    def _new_block(self) -> Block:
        """Create one blank pooled block (no tokens, no physical id)."""
        return self._create_block(prev_block=None,
                                  token_ids=[],
                                  block_size=self._block_size,
                                  allocator=self._allocator,
                                  block_id=None)

    def increase_pool(self):
        """Doubles the internal pool size (growing to at least one slot).
        """
        cur_pool_size = self._pool_size
        # Doubling an empty pool would leave it empty and make the next
        # init_block() fail, so always grow to at least one slot.
        new_pool_size = max(1, cur_pool_size * 2)
        self._pool_size = new_pool_size

        self._free_ids.extend(range(cur_pool_size, new_pool_size))
        self._pool.extend(self._new_block()
                          for _ in range(cur_pool_size, new_pool_size))

    def init_block(self, prev_block: Optional[Block], token_ids: List[int],
                   block_size: int, physical_block_id: Optional[int]) -> Block:
        """Take a free pooled object and re-initialize it in place.

        The pool is grown first when no free slot is left. The chosen
        object's ``__init__`` is invoked again with the given arguments
        (keeping the allocator it already holds) and its slot index is
        stored on it as ``pool_id``.

        Returns:
            The re-initialized pooled block.
        """
        if len(self._free_ids) == 0:
            self.increase_pool()
            assert len(self._free_ids) > 0

        pool_id = self._free_ids.popleft()

        block = self._pool[pool_id]
        block.__init__(  # type: ignore[misc]
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            allocator=block._allocator,  # type: ignore[attr-defined]
            block_id=physical_block_id)
        block.pool_id = pool_id  # type: ignore[attr-defined]
        return block

    def free_block(self, block: Block) -> None:
        """Return ``block``'s pool slot to the front of the free queue."""
        self._free_ids.appendleft(block.pool_id)  # type: ignore[attr-defined]
class BlockList:
    """Block container that keeps a parallel list of physical block ids.

    The id list is maintained alongside the block list so callers can read
    the ids directly instead of rebuilding them from the blocks on every
    iteration of the block manager.
    """

    def __init__(self, blocks: List[Block]):
        """Start empty, then load ``blocks`` through :meth:`update`."""
        self._blocks: List[Block] = []
        self._block_ids: List[int] = []

        self.update(blocks)

    def _add_block_id(self, block_id: Optional[BlockId]) -> None:
        """Append ``block_id`` to the id cache; it must not be None."""
        assert block_id is not None
        self._block_ids.append(block_id)

    def _update_block_id(self, block_index: int,
                         new_block_id: Optional[BlockId]) -> None:
        """Overwrite the cached id at ``block_index``; must not be None."""
        assert new_block_id is not None
        self._block_ids[block_index] = new_block_id

    def update(self, blocks: List[Block]):
        """Adopt ``blocks`` as the block list and rebuild the id cache."""
        self._blocks = blocks

        # Rebuild the cached ids from scratch, in block order.
        self._block_ids = []
        for entry in blocks:
            self._add_block_id(entry.block_id)

    def append_token_ids(self, block_index: int, token_ids: List[int]) -> None:
        """Append ``token_ids`` to the block at ``block_index``.

        The cached id for that position is refreshed when the block's id
        differs after the append.
        """
        target = self._blocks[block_index]
        id_before = target.block_id

        target.append_token_ids(token_ids)

        # CoW or promotion may update the internal block_id
        id_after = target.block_id
        if id_before != id_after:
            self._update_block_id(block_index, id_after)

    def append(self, new_block: Block):
        """Add ``new_block`` at the end and cache its block id."""
        self._blocks.append(new_block)
        self._add_block_id(new_block.block_id)

    def __len__(self) -> int:
        """Number of blocks held."""
        return len(self._blocks)

    def __getitem__(self, block_index: int) -> Block:
        """Block stored at ``block_index``."""
        return self._blocks[block_index]

    def __setitem__(self, block_index: int, new_block: Block) -> None:
        """Replace the block at ``block_index`` and its cached id."""
        self._blocks[block_index] = new_block
        self._update_block_id(block_index, new_block.block_id)

    def reset(self):
        """Drop all blocks and cached ids."""
        self._blocks, self._block_ids = [], []

    def list(self) -> List[Block]:
        """The internal block list itself (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """The internal cached id list itself (not a copy)."""
        return self._block_ids
def
get_all_blocks_recursively
(
last_block
:
Block
)
->
List
[
Block
]:
def
get_all_blocks_recursively
(
last_block
:
Block
)
->
List
[
Block
]:
"""Retrieves all the blocks in a sequence starting from the last block.
"""Retrieves all the blocks in a sequence starting from the last block.
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
705f6a35
...
@@ -113,11 +113,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -113,11 +113,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def
allocate_or_get_null_block
(
self
)
->
Block
:
def
allocate_or_get_null_block
(
self
)
->
Block
:
if
self
.
_null_block
is
None
:
if
self
.
_null_block
is
None
:
self
.
_null_block
=
NullBlock
(
self
.
_null_block
=
NullBlock
(
self
.
allocate_mutable
(
None
,
Device
.
GPU
))
self
.
allocate_mutable
_block
(
None
,
Device
.
GPU
))
return
self
.
_null_block
return
self
.
_null_block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_mutable
_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
device
:
Device
)
->
Block
:
"""Allocates a new mutable block on the specified device.
"""Allocates a new mutable block on the specified device.
Args:
Args:
...
@@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns:
Returns:
Block: The newly allocated mutable block.
Block: The newly allocated mutable block.
"""
"""
return
self
.
_allocators
[
device
].
allocate_mutable
(
prev_block
)
return
self
.
_allocators
[
device
].
allocate_mutable
_block
(
prev_block
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
block_token_ids
:
List
[
List
[
int
]],
device
:
Optional
[
Device
])
->
List
[
Block
]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
block_token_ids (List[List[int]]): The per-block lists of token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return
self
.
_allocators
[
device
].
allocate_immutable_blocks
(
prev_block
,
block_token_ids
)
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
"""Allocates a new immutable block with the provided token IDs on the
"""Allocates a new immutable block with the provided token IDs on the
specified device.
specified device.
...
@@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided
Block: The newly allocated immutable block containing the provided
token IDs.
token IDs.
"""
"""
return
self
.
_allocators
[
device
].
allocate_immutable
(
return
self
.
_allocators
[
device
].
allocate_immutable
_block
(
prev_block
,
token_ids
)
prev_block
,
token_ids
)
def
free
(
self
,
block
:
Block
)
->
None
:
def
free
(
self
,
block
:
Block
)
->
None
:
...
@@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_id
=
block
.
block_id
block_id
=
block
.
block_id
assert
block_id
is
not
None
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
return
allocator
.
free
(
block
)
allocator
.
free
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
"""Creates a new sequence of blocks that shares the same underlying
...
@@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""
"""
return
self
.
_allocators
[
device
].
get_physical_block_id
(
absolute_id
)
return
self
.
_allocators
[
device
].
get_physical_block_id
(
absolute_id
)
def
swap
(
self
,
blocks
:
List
[
Block
],
s
ou
rc
e
_device
:
Device
,
def
swap
(
self
,
blocks
:
List
[
Block
],
src_device
:
Device
,
d
e
st_device
:
Device
)
->
Dict
[
int
,
int
]:
dst_device
:
Device
)
->
Dict
[
int
,
int
]:
"""Execute the swap for the given blocks from source_device
"""Execute the swap for the given blocks from source_device
on to dest_device, save the current swap mapping and append
on to dest_device, save the current swap mapping and append
them to the accumulated `self._swap_mapping` for each
them to the accumulated `self._swap_mapping` for each
...
@@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
Args:
blocks: List of blocks to be swapped.
blocks: List of blocks to be swapped.
s
ou
rc
e
_device (Device): Device to swap the 'blocks' from.
src_device (Device): Device to swap the 'blocks' from.
d
e
st_device (Device): Device to swap the 'blocks' to.
dst_device (Device): Device to swap the 'blocks' to.
Returns:
Returns:
Dict[int, int]: Swap mapping from source_device
Dict[int, int]: Swap mapping from source_device
on to dest_device.
on to dest_device.
"""
"""
s
ou
rc
e
_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
src_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
self
.
_allocators
[
s
ou
rc
e
_device
].
swap_out
(
blocks
)
self
.
_allocators
[
src_device
].
swap_out
(
blocks
)
self
.
_allocators
[
d
e
st_device
].
swap_in
(
blocks
)
self
.
_allocators
[
dst_device
].
swap_in
(
blocks
)
d
e
st_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
dst_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
current_swap_mapping
:
Dict
[
int
,
int
]
=
{}
current_swap_mapping
:
Dict
[
int
,
int
]
=
{}
for
src
,
d
e
st
in
zip
(
s
ou
rc
e
_block_ids
,
d
e
st_block_ids
):
for
src
_block_id
,
dst
_block_id
in
zip
(
src_block_ids
,
dst_block_ids
):
if
src
is
not
None
and
d
e
st
is
not
None
:
if
src
_block_id
is
not
None
and
dst
_block_id
is
not
None
:
self
.
_swap_mapping
[
src
]
=
d
e
st
self
.
_swap_mapping
[
src
_block_id
]
=
dst
_block_id
current_swap_mapping
[
src
]
=
d
e
st
current_swap_mapping
[
src
_block_id
]
=
dst
_block_id
return
current_swap_mapping
return
current_swap_mapping
def
get_num_blocks_touched
(
self
,
def
get_num_blocks_touched
(
self
,
...
@@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device
=
Device
.
GPU
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Forward the computed-block-id query to the GPU allocator.

    All three arguments are passed through unchanged and the GPU
    allocator's result is returned as-is.
    """
    # Prefix caching only supported on GPU.
    gpu_allocator = self._allocators[Device.GPU]
    return gpu_allocator.get_computed_block_ids(prev_computed_block_ids,
                                                block_ids,
                                                skip_last_block_id)
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
# Prefix caching only supported on GPU.
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
get_common_computed_block_ids
(
return
self
.
_allocators
[
device
].
get_common_computed_block_ids
(
seq_block_ids
)
computed_
seq_block_ids
)
@
property
@
property
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
return
frozenset
(
self
.
_block_ids_to_allocator
.
keys
())
return
frozenset
(
self
.
_block_ids_to_allocator
.
keys
())
def promote_to_immutable_block(self, block: Block) -> BlockId:
    """Not provided by this allocator.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
    """Not provided by this allocator.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def
get_and_reset_swaps
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
def
get_and_reset_swaps
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
"""Returns and clears the mapping of source to destination block IDs.
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
Will be called after every swapping operations for now, and after every
...
@@ -341,6 +364,11 @@ class NullBlock(Block):
...
@@ -341,6 +364,11 @@ class NullBlock(Block):
def
token_ids
(
self
)
->
List
[
BlockId
]:
def
token_ids
(
self
)
->
List
[
BlockId
]:
return
self
.
_proxy
.
token_ids
return
self
.
_proxy
.
token_ids
@property
def num_tokens_total(self) -> int:
    """Unsupported on the null block.

    Raises:
        NotImplementedError: Always.
    """
    message = "num_tokens_total is not used for null block"
    raise NotImplementedError(message)
@
property
@
property
def
num_empty_slots
(
self
)
->
BlockId
:
def
num_empty_slots
(
self
)
->
BlockId
:
return
self
.
_proxy
.
num_empty_slots
return
self
.
_proxy
.
num_empty_slots
...
...
vllm/core/block/interfaces.py
View file @
705f6a35
...
@@ -28,6 +28,13 @@ class Block(ABC):
...
@@ -28,6 +28,13 @@ class Block(ABC):
def
token_ids
(
self
)
->
List
[
int
]:
def
token_ids
(
self
)
->
List
[
int
]:
pass
pass
@property
@abstractmethod
def num_tokens_total(self) -> int:
    """Number of tokens up to and including the current block.

    Abstract declaration: this base version has no behavior of its own.
    """
@
property
@
property
@
abstractmethod
@
abstractmethod
def
num_empty_slots
(
self
)
->
int
:
def
num_empty_slots
(
self
)
->
int
:
...
@@ -92,12 +99,18 @@ class Block(ABC):
...
@@ -92,12 +99,18 @@ class Block(ABC):
class
BlockAllocator
(
ABC
):
class
BlockAllocator
(
ABC
):
@
abstractmethod
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
def
allocate_mutable
_block
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
pass
pass
@
abstractmethod
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
])
->
Block
:
token_ids
:
List
[
int
])
->
Block
:
pass
@
abstractmethod
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]])
->
List
[
Block
]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -146,13 +159,19 @@ class BlockAllocator(ABC):
...
@@ -146,13 +159,19 @@ class BlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Abstract declaration; this base version has no behavior of its own.

    NOTE(review): the parameter names suggest implementations report which
    of ``block_ids`` are already computed, optionally ignoring the last
    one -- confirm against the concrete allocators.
    """
@
abstractmethod
@
abstractmethod
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
pass
@
abstractmethod
@
abstractmethod
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
"
BlockId
"
]
:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
BlockId
:
"""NOTE: This should not be used besides Block"""
"""NOTE: This should not be used besides Block"""
pass
pass
...
@@ -174,13 +193,20 @@ class BlockAllocator(ABC):
...
@@ -174,13 +193,20 @@ class BlockAllocator(ABC):
class
DeviceAwareBlockAllocator
(
ABC
):
class
DeviceAwareBlockAllocator
(
ABC
):
@
abstractmethod
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_mutable_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
device
:
Device
)
->
Block
:
pass
@
abstractmethod
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
pass
pass
@
abstractmethod
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
block_token_ids
:
List
[
List
[
int
]],
device
:
Device
)
->
List
[
Block
]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
pass
@abstractmethod
def get_computed_block_ids(self, prev_computed_block_ids: List[int],
                           block_ids: List[int],
                           skip_last_block_id: bool) -> List[int]:
    """Abstract declaration; this base version has no behavior of its own.

    NOTE(review): the parameter names suggest implementations report which
    of ``block_ids`` are already computed, optionally ignoring the last
    one -- confirm against the concrete allocators.
    """
@
abstractmethod
@
abstractmethod
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
swap
(
self
,
blocks
:
List
[
Block
],
s
ou
rc
e
_device
:
Device
,
def
swap
(
self
,
blocks
:
List
[
Block
],
src_device
:
Device
,
d
e
st_device
:
Device
)
->
Dict
[
int
,
int
]:
dst_device
:
Device
)
->
Dict
[
int
,
int
]:
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/block/naive_block.py
View file @
705f6a35
from
typing
import
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
collections
import
deque
from
typing
import
Deque
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Tuple
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
from
vllm.core.block.common
import
(
BlockPool
,
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
@@ -31,28 +32,39 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -31,28 +32,39 @@ class NaiveBlockAllocator(BlockAllocator):
num_blocks
:
int
,
num_blocks
:
int
,
block_size
:
int
,
block_size
:
int
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
block_pool
:
Optional
[
BlockPool
]
=
None
,
):
):
if
block_ids
is
None
:
if
block_ids
is
None
:
block_ids
=
range
(
num_blocks
)
block_ids
=
range
(
num_blocks
)
self
.
_free_block_indices
:
Set
[
BlockId
]
=
set
(
block_ids
)
self
.
_free_block_indices
:
Deque
[
BlockId
]
=
deque
(
block_ids
)
self
.
_all_block_indices
=
frozenset
(
block_ids
)
self
.
_all_block_indices
=
frozenset
(
block_ids
)
assert
len
(
self
.
_all_block_indices
)
==
num_blocks
assert
len
(
self
.
_all_block_indices
)
==
num_blocks
self
.
_refcounter
=
RefCounter
(
self
.
_refcounter
=
RefCounter
(
all_block_indices
=
self
.
_free_block_indices
)
all_block_indices
=
self
.
_free_block_indices
)
self
.
_create_block
=
create_block
self
.
_block_size
=
block_size
self
.
_block_size
=
block_size
self
.
_cow_tracker
=
CopyOnWriteTracker
(
self
.
_cow_tracker
=
CopyOnWriteTracker
(
refcounter
=
self
.
_refcounter
.
as_readonly
(),
refcounter
=
self
.
_refcounter
.
as_readonly
())
allocator
=
self
,
)
if
block_pool
is
None
:
extra_factor
=
4
def
allocate_immutable
(
self
,
# Pre-allocate "num_blocks * extra_factor" block objects.
prev_block
:
Optional
[
Block
],
# The "* extra_factor" is a buffer to allow more block objects
token_ids
:
List
[
int
],
# than physical blocks
device
:
Optional
[
Device
]
=
None
)
->
Block
:
self
.
_block_pool
=
BlockPool
(
self
.
_block_size
,
create_block
,
self
,
num_blocks
*
extra_factor
)
else
:
# In this case, the block pool is provided by the caller,
# which means that there is most likely a need to share
# a block pool between allocators
self
.
_block_pool
=
block_pool
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new immutable block with the given token IDs, linked to
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
the previous block.
...
@@ -66,13 +78,36 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -66,13 +78,36 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated immutable block.
Block: The newly allocated immutable block.
"""
"""
assert
device
is
None
assert
device
is
None
block
=
self
.
allocate_mutable
(
prev_block
=
prev_block
)
block
=
self
.
allocate_mutable
_block
(
prev_block
=
prev_block
)
block
.
append_token_ids
(
token_ids
)
block
.
append_token_ids
(
token_ids
)
return
block
return
block
def
allocate_mutable
(
self
,
def
allocate_immutable_blocks
(
prev_block
:
Optional
[
Block
],
self
,
device
:
Optional
[
Device
]
=
None
)
->
Block
:
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Optional
[
Device
]
=
None
)
->
List
[
Block
]:
assert
device
is
None
num_blocks
=
len
(
block_token_ids
)
block_ids
=
[]
for
i
in
range
(
num_blocks
):
block_ids
.
append
(
self
.
_allocate_block_id
())
blocks
=
[]
for
i
in
range
(
num_blocks
):
prev_block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
token_ids
=
block_token_ids
[
i
],
block_size
=
self
.
_block_size
,
physical_block_id
=
block_ids
[
i
])
blocks
.
append
(
prev_block
)
return
blocks
def
allocate_mutable_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new mutable block, linked to the previous block.
"""Allocates a new mutable block, linked to the previous block.
Args:
Args:
...
@@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated mutable block.
Block: The newly allocated mutable block.
"""
"""
assert
device
is
None
assert
device
is
None
block_id
=
self
.
_allocate_new_block_id
()
block_id
=
self
.
_allocate_block_id
()
return
self
.
_create_block
(
block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
prev_block
=
prev_block
,
token_ids
=
[],
token_ids
=
[],
block_size
=
self
.
_block_size
,
block_id
=
block_id
,
physical_block_id
=
block_id
)
block_size
=
self
.
_block_size
,
return
block
allocator
=
self
,
)
def
_allocate_block_id
(
self
)
->
BlockId
:
if
not
self
.
_free_block_indices
:
def
free
(
self
,
block
:
Block
)
->
None
:
raise
BlockAllocator
.
NoFreeBlocksError
()
assert
block
.
block_id
is
not
None
self
.
_free_block_id
(
block
.
block_id
)
block_id
=
self
.
_free_block_indices
.
popleft
()
self
.
_refcounter
.
incr
(
block_id
)
return
block_id
def
_free_block_id
(
self
,
block
:
Block
)
->
None
:
block_id
=
block
.
block_id
assert
block_id
is
not
None
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
if
refcount
==
0
:
self
.
_free_block_indices
.
appendleft
(
block_id
)
block
.
block_id
=
None
block
.
block_id
=
None
def
free
(
self
,
block
:
Block
,
keep_block_object
:
bool
=
False
)
->
None
:
# Release the physical block id
self
.
_free_block_id
(
block
)
# Release the block object
if
not
keep_block_object
:
self
.
_block_pool
.
free_block
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
memory as the original sequence.
...
@@ -111,7 +165,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -111,7 +165,7 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
source_blocks
=
get_all_blocks_recursively
(
last_block
)
source_blocks
=
get_all_blocks_recursively
(
last_block
)
forked_blocks
=
[]
forked_blocks
:
List
[
Block
]
=
[]
prev_block
=
None
prev_block
=
None
for
block
in
source_blocks
:
for
block
in
source_blocks
:
...
@@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator):
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
assert
refcount
!=
1
,
"can't fork free'd block"
assert
refcount
!=
1
,
"can't fork free'd block"
forked_blocks
.
append
(
forked_block
=
self
.
_block_pool
.
init_block
(
self
.
_create_block
(
prev_block
=
prev_block
,
prev_block
=
prev_block
,
token_ids
=
block
.
token_ids
,
token_ids
=
block
.
token_ids
,
block_size
=
self
.
_block_size
,
block_id
=
block
.
block_id
,
physical_block_id
=
block
.
block_id
)
block_size
=
self
.
_block_size
,
allocator
=
self
,
forked_blocks
.
append
(
forked_block
)
))
prev_block
=
forked_blocks
[
-
1
]
prev_block
=
forked_blocks
[
-
1
]
return
forked_blocks
return
forked_blocks
...
@@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator):
def
get_num_total_blocks
(
self
)
->
int
:
def
get_num_total_blocks
(
self
)
->
int
:
return
len
(
self
.
_all_block_indices
)
return
len
(
self
.
_all_block_indices
)
def
_allocate_new_block_id
(
self
)
->
BlockId
:
if
not
self
.
_free_block_indices
:
raise
BlockAllocator
.
NoFreeBlocksError
()
block_id
=
next
(
iter
(
self
.
_free_block_indices
))
self
.
_refcounter
.
incr
(
block_id
)
self
.
_free_block_indices
.
remove
(
block_id
)
return
block_id
def
_free_block_id
(
self
,
block_id
:
BlockId
)
->
None
:
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
if
refcount
==
0
:
self
.
_free_block_indices
.
add
(
block_id
)
def
get_physical_block_id
(
self
,
absolute_id
:
int
)
->
int
:
def
get_physical_block_id
(
self
,
absolute_id
:
int
)
->
int
:
"""Returns the zero-offset block id on certain block allocator
"""Returns the zero-offset block id on certain block allocator
given the absolute block id.
given the absolute block id.
...
@@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator):
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
return
self
.
_all_block_indices
return
self
.
_all_block_indices
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]
:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
BlockId
:
"""Performs a copy-on-write operation on the given block if it is not
"""Performs a copy-on-write operation on the given block if it is not
appendable.
appendable.
...
@@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
block (Block): The block to check for copy-on-write.
Returns:
Returns:
Optional[
BlockId
]
: The block index of the new block if a copy-on
BlockId: The block index of the new block if a copy-on
-write
-write
operation was performed, or the original block index if
operation was performed, or the original block index if
no copy-on-write was necessary.
no copy-on-write was necessary.
"""
"""
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
src_block_id
=
block
.
block_id
assert
src_block_id
is
not
None
if
self
.
_cow_tracker
.
is_appendable
(
block
):
return
src_block_id
self
.
_free_block_id
(
block
)
trg_block_id
=
self
.
_allocate_block_id
()
self
.
_cow_tracker
.
record_cow
(
src_block_id
,
trg_block_id
)
return
trg_block_id
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Returns the copy-on-write source->destination mapping and clears it.
"""Returns the copy-on-write source->destination mapping and clears it.
...
@@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
pass
pass
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
"""No prefix caching here => return empty list
"""
return
[]
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Determine blocks that can be skipped in prefill.
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
Since the naive allocator does not support prefix caching, always return
...
@@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator):
return
[]
return
[]
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
raise
NotImplementedError
(
"There is no promotion for naive blocks"
)
def
get_num_blocks_touched
(
self
,
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
blocks
:
List
[
Block
],
...
@@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator):
def
swap_out
(
self
,
blocks
:
List
[
Block
])
->
None
:
def
swap_out
(
self
,
blocks
:
List
[
Block
])
->
None
:
for
block
in
blocks
:
for
block
in
blocks
:
self
.
free
(
block
)
self
.
_
free
_block_id
(
block
)
def
swap_in
(
self
,
blocks
:
List
[
Block
])
->
None
:
def
swap_in
(
self
,
blocks
:
List
[
Block
])
->
None
:
for
block
in
blocks
:
for
block
in
blocks
:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if
block
.
is_full
:
if
block
.
is_full
:
al
loc
=
self
.
allocate_immutable
(
block
.
prev
_block
,
tmp_b
loc
k
=
self
.
allocate_immutable_block
(
block
.
token_ids
)
prev_block
=
block
.
prev_block
,
token_ids
=
block
.
token_ids
)
else
:
else
:
alloc
=
self
.
allocate_mutable
(
block
.
prev_block
)
tmp_block
=
self
.
allocate_mutable_block
(
alloc
.
append_token_ids
(
block
.
token_ids
)
prev_block
=
block
.
prev_block
)
block
.
block_id
=
alloc
.
block_id
tmp_block
.
append_token_ids
(
block
.
token_ids
)
block_id
=
tmp_block
.
block_id
tmp_block
.
block_id
=
None
self
.
_block_pool
.
free_block
(
tmp_block
)
block
.
block_id
=
block_id
# Assign block_id
class
NaiveBlock
(
Block
):
class
NaiveBlock
(
Block
):
...
@@ -315,11 +382,12 @@ class NaiveBlock(Block):
...
@@ -315,11 +382,12 @@ class NaiveBlock(Block):
self
.
_append_token_ids_no_cow
(
token_ids
)
self
.
_append_token_ids_no_cow
(
token_ids
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
"""Appends the given token IDs to the block
, instructing the allocator
"""Appends the given token IDs to the block
and performs a
to perform a
copy-on-write if necessary.
copy-on-write if necessary.
Args:
Args:
token_ids (List[int]): The token IDs to be appended to the block.
token_ids (Optional[List[int]]): The token IDs to be appended
to the block.
"""
"""
self
.
_append_token_ids_no_cow
(
token_ids
)
self
.
_append_token_ids_no_cow
(
token_ids
)
...
@@ -328,7 +396,16 @@ class NaiveBlock(Block):
...
@@ -328,7 +396,16 @@ class NaiveBlock(Block):
self
.
_cow_target
))
self
.
_cow_target
))
def
_append_token_ids_no_cow
(
self
,
token_ids
:
List
[
int
])
->
None
:
def
_append_token_ids_no_cow
(
self
,
token_ids
:
List
[
int
])
->
None
:
assert
self
.
num_empty_slots
>=
len
(
token_ids
)
"""Appends the given token IDs to the block
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
if
len
(
token_ids
)
==
0
:
return
assert
len
(
token_ids
)
<=
self
.
num_empty_slots
self
.
_token_ids
.
extend
(
token_ids
)
self
.
_token_ids
.
extend
(
token_ids
)
@
property
@
property
...
@@ -361,12 +438,17 @@ class NaiveBlock(Block):
...
@@ -361,12 +438,17 @@ class NaiveBlock(Block):
@
property
@
property
def
num_empty_slots
(
self
)
->
int
:
def
num_empty_slots
(
self
)
->
int
:
return
self
.
_block_size
-
len
(
self
.
_
token_ids
)
return
self
.
_block_size
-
len
(
self
.
token_ids
)
@
property
@
property
def
token_ids
(
self
)
->
List
[
int
]:
def
token_ids
(
self
)
->
List
[
int
]:
return
self
.
_token_ids
return
self
.
_token_ids
@
property
def
num_tokens_total
(
self
)
->
int
:
raise
NotImplementedError
(
"num_tokens_total is not used for naive block"
)
@
property
@
property
def
block_size
(
self
)
->
int
:
def
block_size
(
self
)
->
int
:
return
self
.
_block_size
return
self
.
_block_size
...
...
vllm/core/block/prefix_caching_block.py
View file @
705f6a35
"""Token blocks."""
"""Token blocks."""
from
itertools
import
takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Tuple
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
NaiveBlockAllocator
)
from
vllm.core.evictor_v2
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.evictor_v2
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
@@ -19,6 +19,30 @@ PrefixHash = int
...
@@ -19,6 +19,30 @@ PrefixHash = int
_DEFAULT_LAST_ACCESSED_TIME
=
-
1
_DEFAULT_LAST_ACCESSED_TIME
=
-
1
class
BlockTracker
:
"""Used to track the status of a block inside the prefix caching allocator
"""
__slots__
=
(
"active"
,
"last_accessed"
,
"computed"
)
def
reset
(
self
):
self
.
last_accessed
:
float
=
_DEFAULT_LAST_ACCESSED_TIME
self
.
computed
:
bool
=
False
def
__init__
(
self
):
self
.
active
:
bool
=
False
self
.
reset
()
def
enable
(
self
):
assert
not
self
.
active
self
.
active
=
True
self
.
reset
()
def
disable
(
self
):
assert
self
.
active
self
.
active
=
False
self
.
reset
()
class
PrefixCachingBlockAllocator
(
BlockAllocator
):
class
PrefixCachingBlockAllocator
(
BlockAllocator
):
"""A block allocator that implements prefix caching.
"""A block allocator that implements prefix caching.
...
@@ -41,12 +65,26 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -41,12 +65,26 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
eviction_policy
:
EvictionPolicy
=
EvictionPolicy
.
LRU
,
eviction_policy
:
EvictionPolicy
=
EvictionPolicy
.
LRU
,
):
):
if
block_ids
is
None
:
block_ids
=
range
(
num_blocks
)
self
.
_block_size
=
block_size
# A mapping of prefix hash to block index. All blocks which have a
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
# prefix hash will be in this dict, even if they have refcount 0.
self
.
_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
self
.
_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
# A mapping of blockId to Block to track those cached blocks
# Used to track status of each physical block id
self
.
_blocks
:
Dict
[
BlockId
,
Block
]
=
{}
self
.
_block_tracker
:
Dict
[
BlockId
,
BlockTracker
]
=
{}
for
block_id
in
block_ids
:
self
.
_block_tracker
[
block_id
]
=
BlockTracker
()
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
extra_factor
=
4
self
.
_block_pool
=
BlockPool
(
self
.
_block_size
,
self
.
_create_block
,
self
,
num_blocks
*
extra_factor
)
# An allocator for blocks that do not have prefix hashes.
# An allocator for blocks that do not have prefix hashes.
self
.
_hashless_allocator
=
NaiveBlockAllocator
(
self
.
_hashless_allocator
=
NaiveBlockAllocator
(
...
@@ -54,10 +92,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -54,10 +92,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks
=
num_blocks
,
num_blocks
=
num_blocks
,
block_size
=
block_size
,
block_size
=
block_size
,
block_ids
=
block_ids
,
block_ids
=
block_ids
,
block_pool
=
self
.
_block_pool
,
# Share block pool here
)
)
self
.
_block_size
=
block_size
# Evitor used to maintain how we want to handle those computed blocks
# Evitor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
# if we find memory pressure is high.
self
.
evictor
:
Evictor
=
make_evictor
(
eviction_policy
)
self
.
evictor
:
Evictor
=
make_evictor
(
eviction_policy
)
...
@@ -68,9 +105,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -68,9 +105,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self
.
_refcounter
=
self
.
_hashless_allocator
.
refcounter
self
.
_refcounter
=
self
.
_hashless_allocator
.
refcounter
self
.
_cow_tracker
=
CopyOnWriteTracker
(
self
.
_cow_tracker
=
CopyOnWriteTracker
(
refcounter
=
self
.
_refcounter
.
as_readonly
(),
refcounter
=
self
.
_refcounter
.
as_readonly
())
allocator
=
self
,
)
# Implements Block.Factory.
# Implements Block.Factory.
def
_create_block
(
def
_create_block
(
...
@@ -90,14 +125,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -90,14 +125,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
token_ids
=
token_ids
,
token_ids
=
token_ids
,
block_size
=
block_size
,
block_size
=
block_size
,
block_id
=
block_id
,
block_id
=
block_id
,
prefix_caching_
allocator
=
allocator
,
allocator
=
allocator
,
computed
=
computed
,
computed
=
computed
,
)
)
def
allocate_immutable
(
self
,
def
allocate_immutable
_block
(
self
,
prev_block
:
Optional
[
Block
],
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates an immutable block with the given token IDs, reusing cached
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
blocks if possible.
...
@@ -111,29 +146,41 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -111,29 +146,41 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert
device
is
None
assert
device
is
None
assert_prefix_caching_block_or_none
(
prev_block
)
assert_prefix_caching_block_or_none
(
prev_block
)
block
=
self
.
_create_block
(
# First, try to create a block that points to cached data
prev_block
=
prev_block
,
block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
token_ids
=
token_ids
,
token_ids
=
token_ids
,
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
allocator
=
self
,
physical_block_id
=
None
)
)
assert
block
.
content_hash
is
not
None
assert
block
.
content_hash
is
not
None
cached_block_id
=
self
.
_cached_blocks
.
get
(
block
.
content_hash
,
None
)
cached_block_id
=
self
.
_cached_blocks
.
get
(
block
.
content_hash
,
None
)
if
cached_block_id
is
not
None
:
if
cached_block_id
is
not
None
:
block
.
block_id
=
cached_block_id
block
.
block_id
=
cached_block_id
self
.
_incr_refcount_cached_block
(
block
,
block
.
block_id
)
self
.
_incr_refcount_cached_block
(
block
)
return
block
return
block
self
.
_block_pool
.
free_block
(
block
)
block
=
self
.
allocate_mutable
(
prev_block
)
# No cached block => Allocate a new block
block
=
self
.
allocate_mutable_block
(
prev_block
)
block
.
append_token_ids
(
token_ids
)
block
.
append_token_ids
(
token_ids
)
assert
block
.
content_hash
is
not
None
return
block
return
block
def
allocate_mutable
(
self
,
def
allocate_immutable_blocks
(
prev_block
:
Optional
[
Block
],
self
,
device
:
Optional
[
Device
]
=
None
)
->
Block
:
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Optional
[
Device
]
=
None
)
->
List
[
Block
]:
blocks
=
[]
for
token_ids
in
block_token_ids
:
prev_block
=
self
.
allocate_immutable_block
(
prev_block
=
prev_block
,
token_ids
=
token_ids
,
device
=
device
)
blocks
.
append
(
prev_block
)
return
blocks
def
allocate_mutable_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a mutable block. If there are no free blocks, this will
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
evict unused cached blocks.
...
@@ -147,113 +194,154 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -147,113 +194,154 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert
device
is
None
assert
device
is
None
assert_prefix_caching_block_or_none
(
prev_block
)
assert_prefix_caching_block_or_none
(
prev_block
)
try
:
block_id
=
self
.
_allocate_block_id
()
block
=
self
.
_hashless_allocator
.
allocate_mutable
(
block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
prev_block
=
prev_block
)
token_ids
=
[],
block_size
=
self
.
_block_size
,
physical_block_id
=
block_id
)
assert
not
block
.
computed
assert
block
.
content_hash
is
None
return
block
assert
block
.
block_id
not
in
self
.
_blocks
def
_incr_refcount_cached_block
(
self
,
block
:
Block
)
->
None
:
assert
block
.
block_id
is
not
None
# Set this block to be "computed" since it is pointing to a
self
.
_blocks
[
block
.
block_id
]
=
block
# cached block id (which was already computed)
return
block
block
.
computed
=
True
except
BlockAllocator
.
NoFreeBlocksError
:
# We must check the unused cached blocks before raising OOM.
pass
# If the evictor has blocks available for eviction, evict a block
block_id
=
block
.
block_id
# and return it.
assert
block_id
is
not
None
if
self
.
evictor
.
num_blocks
>
0
:
# here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id
,
content_hash_to_evict
=
self
.
evictor
.
evict
()
_block_id
=
self
.
_cached_blocks
[
content_hash_to_evict
]
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
assert
self
.
_refcounter
.
get
(
_block_id
)
==
0
if
refcount
==
1
:
assert
_block_id
==
block_id
# In case a cached block was evicted, restore its tracking
if
block_id
in
self
.
evictor
:
self
.
evictor
.
remove
(
block_id
)
self
.
_
cached_blocks
.
pop
(
content_hash_to_evict
)
self
.
_
track_block_id
(
block_id
,
computed
=
True
)
self
.
_refcounter
.
incr
(
block_id
)
def
_decr_refcount_cached_block
(
self
,
block
:
Block
)
->
None
:
# Ensure this is immutable/cached block
assert
block
.
content_hash
is
not
None
# the block comes from evictor already contain computed result
block_id
=
block
.
block_id
block
=
self
.
_create_block
(
assert
block_id
is
not
None
prev_block
=
prev_block
,
token_ids
=
[],
block_size
=
self
.
_block_size
,
allocator
=
self
,
block_id
=
block_id
,
computed
=
True
,
)
assert
block
.
content_hash
is
None
assert
block
.
block_id
not
in
self
.
_blocks
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
assert
block
.
block_id
is
not
None
if
refcount
>
0
:
self
.
_blocks
[
block
.
block_id
]
=
block
block
.
block_id
=
None
return
block
return
else
:
assert
refcount
==
0
# No
b
lo
ck available in hashless allocator, nor in unused cache blocks.
# No lo
nger used
raise
BlockAllocator
.
NoFreeBlocksError
()
assert
block
.
content_hash
in
self
.
_cached_blocks
def
_incr_refcount_cached_block
(
self
,
block
:
Block
,
# Add the cached block to the evictor
block_id
:
BlockId
)
->
None
:
# (This keeps the cached block around so it can be reused)
# now _incr_refcount_cached_block comes from two place
self
.
evictor
.
add
(
block_id
,
block
.
content_hash
,
block
.
num_tokens_total
,
# allocate_immutable/promote_to_immutable_block where hit
self
.
_block_tracker
[
block_id
].
last_accessed
)
# _cached_blocks hash key.
# In both cases, it means that already exists a already
# computed block which shared with block now
block
.
computed
=
True
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
# Stop tracking the block
self
.
_untrack_block_id
(
block_id
)
block
.
block_id
=
None
def
_decr_refcount_hashless_block
(
self
,
block
:
Block
)
->
None
:
block_id
=
block
.
block_id
assert
block_id
is
not
None
# We may have a fork case where block is shared,
# in which case, we cannot remove it from tracking
refcount
=
self
.
_refcounter
.
get
(
block_id
)
if
refcount
==
1
:
if
refcount
==
1
:
# if block get referred, then it shall not be in evictor
self
.
_untrack_block_id
(
block_id
)
# and put it into _blocks for tracking
if
block_id
in
self
.
evictor
:
self
.
evictor
.
remove
(
block_id
)
self
.
_blocks
[
block_id
]
=
block
def
free
(
self
,
block
:
Block
)
->
None
:
# Decrement refcount of the block_id, but do not free the block object
"""Decrement the refcount of the block. If the decremented refcount is
# itself (will be handled by the caller)
zero, store the block in the freelist.
self
.
_hashless_allocator
.
free
(
block
,
keep_block_object
=
True
)
If the block has a content hash (meaning it is immutable), then we will
def
_allocate_block_id
(
self
)
->
BlockId
:
keep the block around in case future allocations require it.
"""First tries to allocate a block id from the hashless allocator,
and if there are no blocks, then tries to evict an unused cached block.
"""
"""
assert
(
block
.
block_id
hashless_block_id
=
self
.
_maybe_allocate_hashless_block_id
()
is
not
None
),
"freeing unallocated block is undefined"
if
hashless_block_id
is
not
None
:
return
hashless_block_id
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
evicted_block_id
=
self
.
_maybe_allocate_evicted_block_id
()
if
evicted_block_id
is
not
None
:
return
evicted_block_id
block
.
block_id
=
None
# No block available in hashless allocator, nor in unused cache blocks.
raise
BlockAllocator
.
NoFreeBlocksError
()
def
_maybe_allocate_hashless_block_id
(
self
)
->
Optional
[
BlockId
]:
try
:
# Allocate mutable block and extract its block_id
block
=
self
.
_hashless_allocator
.
allocate_mutable_block
(
prev_block
=
None
)
block_id
=
block
.
block_id
self
.
_block_pool
.
free_block
(
block
)
self
.
_track_block_id
(
block_id
,
computed
=
False
)
return
block_id
except
BlockAllocator
.
NoFreeBlocksError
:
return
None
def
_free_block_id_for_block
(
self
,
block_id
:
BlockId
,
def
_maybe_allocate_evicted_block_id
(
self
)
->
Optional
[
BlockId
]:
block
:
Block
)
->
None
:
if
self
.
evictor
.
num_blocks
==
0
:
assert
isinstance
(
block
,
PrefixCachingBlock
)
return
None
# if we comes from promote_to_immutable_block, it means that
# block.content_hash is never None.
# However we need to release the same content block, so that
# physical block could get reused.
if
block
.
block_id
!=
block_id
or
block
.
content_hash
is
None
:
refcount
=
self
.
_refcounter
.
get
(
block_id
)
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
assert
block
.
block_id
is
not
None
refcount
=
self
.
_refcounter
.
get
(
block
.
block_id
)
if
refcount
==
1
:
del
self
.
_blocks
[
block
.
block_id
]
return
self
.
_hashless_allocator
.
free
(
block
)
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
# Here we get an evicted block, which is only added
# into evictor if its ref counter is 0
# and since its content would be changed, we need
# to remove it from _cached_blocks's tracking list
block_id
,
content_hash_to_evict
=
self
.
evictor
.
evict
()
# Sanity checks
assert
content_hash_to_evict
in
self
.
_cached_blocks
_block_id
=
self
.
_cached_blocks
[
content_hash_to_evict
]
assert
self
.
_refcounter
.
get
(
_block_id
)
==
0
assert
_block_id
==
block_id
# If no longer used, add the block to the evictor.
self
.
_cached_blocks
.
pop
(
content_hash_to_evict
)
if
refcount
==
0
:
assert
block
.
content_hash
in
self
.
_cached_blocks
self
.
_refcounter
.
incr
(
block_id
)
assert
block
.
block_id
is
not
None
self
.
_track_block_id
(
block_id
,
computed
=
False
)
del
self
.
_blocks
[
block
.
block_id
]
self
.
evictor
.
add
(
block
.
block_id
,
block
.
content_hash
,
return
block_id
block
.
num_tokens_total
,
block
.
last_accessed
)
def
_free_block_id
(
self
,
block
:
Block
)
->
None
:
"""Decrements the refcount of the block. The block may be in two
possible states: (1) immutable/cached or (2) mutable/hashless.
In the first case, the refcount is decremented directly and the block
may be possibly added to the evictor. In other case, hashless
allocator free(..) with keep_block_object=True is called to only free
the block id (since the block object may be reused by the caller)
"""
block_id
=
block
.
block_id
assert
block_id
is
not
None
,
"Freeing unallocated block is undefined"
if
block
.
content_hash
is
not
None
:
# Immutable: This type of block is always cached, and we want to
# keep it in the evictor for future reuse
self
.
_decr_refcount_cached_block
(
block
)
else
:
# Mutable: This type of block is not cached, so we release it
# directly to the hashless allocator
self
.
_decr_refcount_hashless_block
(
block
)
assert
block
.
block_id
is
None
def
free
(
self
,
block
:
Block
,
keep_block_object
:
bool
=
False
)
->
None
:
"""Release the block (look at free_block_id(..) docs)
"""
# Release the physical block index
self
.
_free_block_id
(
block
)
# Release the block object to the pool
if
not
keep_block_object
:
self
.
_block_pool
.
free_block
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
"""Creates a new sequence of blocks that shares the same underlying
...
@@ -268,20 +356,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -268,20 +356,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
"""
source_blocks
=
get_all_blocks_recursively
(
last_block
)
source_blocks
=
get_all_blocks_recursively
(
last_block
)
forked_blocks
=
[]
forked_blocks
:
List
[
Block
]
=
[]
prev_block
=
None
prev_block
=
None
for
block
in
source_blocks
:
for
block
in
source_blocks
:
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
block_id
=
block
.
block_id
assert
refcount
!=
1
,
"can't fork free'd block"
assert
block_id
is
not
None
forked_blocks
.
append
(
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
self
.
_create_block
(
assert
refcount
!=
1
,
"can't fork free'd block_id = {}"
.
format
(
prev_block
=
prev_block
,
block_id
)
token_ids
=
block
.
token_ids
,
block_id
=
block
.
block_id
,
forked_block
=
self
.
_block_pool
.
init_block
(
block_size
=
self
.
_block_size
,
prev_block
=
prev_block
,
allocator
=
self
,
token_ids
=
block
.
token_ids
,
))
block_size
=
self
.
_block_size
,
physical_block_id
=
block_id
)
forked_blocks
.
append
(
forked_block
)
prev_block
=
forked_blocks
[
-
1
]
prev_block
=
forked_blocks
[
-
1
]
return
forked_blocks
return
forked_blocks
...
@@ -326,7 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -326,7 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Note that if we already have a cached block with the same content, we
Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached
will replace the newly-promoted block's mapping with the existing cached
block.
block
id
.
Args:
Args:
block: The mutable block to be promoted.
block: The mutable block to be promoted.
...
@@ -335,23 +426,30 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -335,23 +426,30 @@ class PrefixCachingBlockAllocator(BlockAllocator):
BlockId: Either the original block index, or the block index of
BlockId: Either the original block index, or the block index of
the previously cached block matching the same content.
the previously cached block matching the same content.
"""
"""
# Ensure block can be promoted
assert
block
.
content_hash
is
not
None
assert
block
.
content_hash
is
not
None
assert
block
.
block_id
is
not
None
assert
block
.
block_id
is
not
None
assert
self
.
_refcounter
.
get
(
block
.
block_id
)
>
0
assert
self
.
_refcounter
.
get
(
block
.
block_id
)
>
0
# If the content hash does not have a corresponding cached block,
# set this block as the cached block.
if
block
.
content_hash
not
in
self
.
_cached_blocks
:
if
block
.
content_hash
not
in
self
.
_cached_blocks
:
# No cached content hash => Set this block as cached
# (Note that this block is not computed yet =>
# Will be computed after free())
self
.
_cached_blocks
[
block
.
content_hash
]
=
block
.
block_id
self
.
_cached_blocks
[
block
.
content_hash
]
=
block
.
block_id
else
:
return
block
.
block_id
self
.
_free_block_id_for_block
(
self
.
_cached_blocks
[
block
.
content
_
hash
],
block
)
# Reuse the cached
content
hash
self
.
_
in
cr_refcount_
cached
_block
(
self
.
_
de
cr_refcount_
hashless
_block
(
block
)
block
,
self
.
_cached_blocks
[
block
.
content_hash
]
)
block
.
block_id
=
self
.
_cached_blocks
[
block
.
content_hash
]
return
self
.
_cached_blocks
[
block
.
content_hash
]
# Increment refcount of the cached block and (possibly) restore
# it from the evictor.
# Note that in this case, the block is marked as computed
self
.
_incr_refcount_cached_block
(
block
)
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
return
block
.
block_id
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
BlockId
:
"""Performs a copy-on-write operation on the given block if it is not
"""Performs a copy-on-write operation on the given block if it is not
appendable.
appendable.
...
@@ -359,11 +457,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -359,11 +457,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
block (Block): The block to check for copy-on-write.
Returns:
Returns:
Optional[
BlockId
]
: The block index of the new block if a copy-on
BlockId: The block index of the new block if a copy-on
-write
-write
operation was performed, or the original block index if
operation was performed, or the original block index if
no copy-on-write was necessary.
no copy-on-write was necessary.
"""
"""
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
src_block_id
=
block
.
block_id
assert
src_block_id
is
not
None
if
self
.
_cow_tracker
.
is_appendable
(
block
):
return
src_block_id
self
.
_free_block_id
(
block
)
trg_block_id
=
self
.
_allocate_block_id
()
self
.
_cow_tracker
.
record_cow
(
src_block_id
,
trg_block_id
)
return
trg_block_id
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Returns the copy-on-write source->destination mapping and clears it.
"""Returns the copy-on-write source->destination mapping and clears it.
...
@@ -383,8 +492,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -383,8 +492,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
"""
for
block_id
in
block_ids
:
for
block_id
in
block_ids
:
if
block_id
in
self
.
_blocks
:
if
self
.
_block_tracker
[
block_id
].
active
:
self
.
_block
s
[
block_id
].
last_accessed
=
now
self
.
_block
_tracker
[
block_id
].
last_accessed
=
now
elif
block_id
in
self
.
evictor
:
elif
block_id
in
self
.
evictor
:
self
.
evictor
.
update
(
block_id
,
now
)
self
.
evictor
.
update
(
block_id
,
now
)
else
:
else
:
...
@@ -392,25 +501,46 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -392,25 +501,46 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"Mark block as accessed which is not belonged to GPU"
)
"Mark block as accessed which is not belonged to GPU"
)
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as computed, used in prefix caching."""
raise
NotImplementedError
(
"Marking as computed is incremental"
)
for
block_id
in
block_ids
:
def
_track_block_id
(
self
,
block_id
:
Optional
[
BlockId
],
if
block_id
in
self
.
_blocks
:
computed
:
bool
)
->
None
:
# only those full block is valid for prefix caching
assert
block_id
is
not
None
if
self
.
_blocks
[
block_id
].
is_full
:
self
.
_block_tracker
[
block_id
].
enable
()
self
.
_blocks
[
block_id
].
computed
=
True
self
.
_block_tracker
[
block_id
].
computed
=
computed
elif
block_id
not
in
self
.
evictor
:
raise
ValueError
(
f
"Mark
{
block_id
=
}
as computed which "
def
_untrack_block_id
(
self
,
block_id
:
Optional
[
BlockId
])
->
None
:
"is not belonged to GPU"
)
assert
block_id
is
not
None
self
.
_block_tracker
[
block_id
].
disable
()
def
block_is_computed
(
self
,
block_id
:
int
)
->
bool
:
def
block_is_computed
(
self
,
block_id
:
int
)
->
bool
:
if
block_id
in
self
.
_blocks
:
if
self
.
_block_tracker
[
block_id
].
active
:
return
self
.
_block
s
[
block_id
].
computed
return
self
.
_block
_tracker
[
block_id
].
computed
else
:
else
:
return
block_id
in
self
.
evictor
return
block_id
in
self
.
evictor
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
=
True
)
->
List
[
int
]:
prev_prefix_size
=
len
(
prev_computed_block_ids
)
cur_size
=
len
(
block_ids
)
if
skip_last_block_id
:
cur_size
-=
1
# Sanity checks
assert
cur_size
>=
0
assert
prev_prefix_size
<=
cur_size
ret
=
prev_computed_block_ids
for
i
in
range
(
prev_prefix_size
,
cur_size
):
block_id
=
block_ids
[
i
]
if
self
.
block_is_computed
(
block_id
):
ret
.
append
(
block_id
)
return
ret
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Return the block ids that are common for a given sequence group.
"""Return the block ids that are common for a given sequence group.
Only those blocks that are immutable and already be marked
Only those blocks that are immutable and already be marked
...
@@ -421,14 +551,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -421,14 +551,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# prompt is cached. This would cause erroneous behavior in model
# prompt is cached. This would cause erroneous behavior in model
# runner.
# runner.
ids_list
=
[
list
(
takewhile
(
lambda
block_id
:
self
.
block_is_computed
(
block_id
),
seq
[:
-
1
]))
for
seq
in
seq_block_ids
]
# It returns a list of int although type annotation says list of string.
# It returns a list of int although type annotation says list of string.
return
commonprefix
([
return
commonprefix
([
ids
for
ids
in
ids_list
# type: ignore
ids
for
ids
in
computed_seq_block_ids
# type: ignore
if
ids
!=
[]
if
ids
!=
[]
])
])
...
@@ -470,10 +595,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -470,10 +595,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped out.
blocks: List of blocks to be swapped out.
"""
"""
for
block
in
blocks
:
for
block
in
blocks
:
self
.
free
(
block
)
self
.
_
free
_block_id
(
block
)
def
swap_in
(
self
,
blocks
:
List
[
Block
])
->
None
:
def
swap_in
(
self
,
blocks
:
List
[
Block
])
->
None
:
"""Execute the swap in
t
actions. Change the block id from
"""Execute the swap in actions. Change the block id from
old allocator to current allocator for each block to finish
old allocator to current allocator for each block to finish
the block table update.
the block table update.
...
@@ -481,13 +606,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -481,13 +606,22 @@ class PrefixCachingBlockAllocator(BlockAllocator):
blocks: List of blocks to be swapped in.
blocks: List of blocks to be swapped in.
"""
"""
for
block
in
blocks
:
for
block
in
blocks
:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if
block
.
is_full
:
if
block
.
is_full
:
al
loc
=
self
.
allocate_immutable
(
block
.
prev
_block
,
tmp_b
loc
k
=
self
.
allocate_immutable_block
(
block
.
token_ids
)
prev_block
=
block
.
prev_block
,
token_ids
=
block
.
token_ids
)
else
:
else
:
alloc
=
self
.
allocate_mutable
(
block
.
prev_block
)
tmp_block
=
self
.
allocate_mutable_block
(
alloc
.
append_token_ids
(
block
.
token_ids
)
prev_block
=
block
.
prev_block
)
block
.
block_id
=
alloc
.
block_id
tmp_block
.
append_token_ids
(
block
.
token_ids
)
block_id
=
tmp_block
.
block_id
self
.
_block_pool
.
free_block
(
tmp_block
)
block
.
block_id
=
block_id
# Assign block_id
class
PrefixCachingBlock
(
Block
):
class
PrefixCachingBlock
(
Block
):
...
@@ -504,7 +638,7 @@ class PrefixCachingBlock(Block):
...
@@ -504,7 +638,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
block_size (int): The maximum number of token IDs that can be stored in
the block.
the block.
prefix_caching_
allocator (BlockAllocator): The prefix
allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
of this block. Defaults to None.
...
@@ -515,31 +649,55 @@ class PrefixCachingBlock(Block):
...
@@ -515,31 +649,55 @@ class PrefixCachingBlock(Block):
prev_block
:
Optional
[
Block
],
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
block_size
:
int
,
block_size
:
int
,
prefix_caching_
allocator
:
BlockAllocator
,
allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
block_id
:
Optional
[
int
]
=
None
,
computed
:
bool
=
False
,
computed
:
bool
=
False
,
):
):
assert
isinstance
(
p
refix
_c
aching
_a
llocator
,
assert
isinstance
(
allocator
,
P
refix
C
aching
BlockA
llocator
),
(
PrefixCachingBlockAllocator
),
(
"Currently this class is only tested with "
"Currently this class is only te
sted
with "
"PrefixCachingBlockAllocator. Got in
ste
a
d
allocator = {}"
.
format
(
"PrefixCachingBlockA
llocator
."
)
a
llocator
)
)
assert_prefix_caching_block_or_none
(
prev_block
)
assert_prefix_caching_block_or_none
(
prev_block
)
self
.
_prev_block
=
prev_block
self
.
_prev_block
=
prev_block
self
.
_cached_content_hash
:
Optional
[
int
]
=
None
self
.
_cached_content_hash
:
Optional
[
int
]
=
None
self
.
_cached_num_tokens_total
:
Optional
[
int
]
=
None
self
.
_cached_num_tokens_total
:
int
=
0
self
.
_
prefix_caching_allocator
=
prefix_caching_
allocator
self
.
_
allocator
=
allocator
self
.
_last_accessed
:
float
=
_DEFAULT_LAST_ACCESSED_TIME
self
.
_last_accessed
:
float
=
_DEFAULT_LAST_ACCESSED_TIME
self
.
_computed
=
computed
self
.
_computed
=
computed
self
.
_block
=
NaiveBlock
(
# On the first time, we create the block object, and next we only
prev_block
=
prev_block
,
# reinitialize it
token_ids
=
token_ids
,
if
hasattr
(
self
,
"_block"
):
block_size
=
block_size
,
self
.
_block
.
__init__
(
# type: ignore[has-type]
block_id
=
block_id
,
prev_block
=
prev_block
,
allocator
=
prefix_caching_allocator
,
token_ids
=
token_ids
,
_cow_target
=
self
,
block_size
=
block_size
,
)
block_id
=
block_id
,
allocator
=
self
.
_allocator
)
else
:
self
.
_block
=
NaiveBlock
(
prev_block
=
prev_block
,
token_ids
=
token_ids
,
block_size
=
block_size
,
block_id
=
block_id
,
allocator
=
self
.
_allocator
)
self
.
_update_num_tokens_total
()
def
_update_num_tokens_total
(
self
):
"""Incrementally computes the number of tokens that there is
till the current block (included)
"""
res
=
0
# Add all previous blocks
if
self
.
_prev_block
is
not
None
:
res
+=
self
.
_prev_block
.
num_tokens_total
# Add current block
res
+=
len
(
self
.
token_ids
)
self
.
_cached_num_tokens_total
=
res
@
property
@
property
def
computed
(
self
)
->
bool
:
def
computed
(
self
)
->
bool
:
...
@@ -561,22 +719,28 @@ class PrefixCachingBlock(Block):
...
@@ -561,22 +719,28 @@ class PrefixCachingBlock(Block):
"""Appends the given token IDs to the block and registers the block as
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
immutable if the block becomes full.
Internally, the naive block handles CoW.
Args:
Args:
token_ids (List[int]): The token IDs to be appended to the block.
token_ids (List[int]): The token IDs to be appended to the block.
"""
"""
assert
token_ids
# Ensure this is mutable block (not promoted)
assert
self
.
content_hash
is
None
assert
not
self
.
computed
if
len
(
token_ids
)
==
0
:
return
# naive block handles CoW.
# Ensure there are input tokens
assert
token_ids
,
"Got token_ids = {}"
.
format
(
token_ids
)
# Naive block handles CoW.
self
.
_block
.
append_token_ids
(
token_ids
)
self
.
_block
.
append_token_ids
(
token_ids
)
self
.
_update_num_tokens_total
()
# If the content hash is present, then the block can be made immutable.
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# Register ourselves with the allocator, potentially replacing the
# physical block index.
# physical block index.
if
self
.
content_hash
is
not
None
:
if
self
.
content_hash
is
not
None
:
self
.
block_id
=
(
self
.
_prefix_caching_allocator
.
self
.
block_id
=
self
.
_allocator
.
promote_to_immutable_block
(
self
)
promote_to_immutable_block
(
self
))
@
property
@
property
def
block_id
(
self
)
->
Optional
[
int
]:
def
block_id
(
self
)
->
Optional
[
int
]:
...
@@ -596,23 +760,6 @@ class PrefixCachingBlock(Block):
...
@@ -596,23 +760,6 @@ class PrefixCachingBlock(Block):
@
property
@
property
def
num_tokens_total
(
self
)
->
int
:
def
num_tokens_total
(
self
)
->
int
:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if
self
.
_cached_num_tokens_total
is
not
None
:
return
self
.
_cached_num_tokens_total
_block
:
Optional
[
Block
]
=
self
self
.
_cached_num_tokens_total
=
0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while
_block
is
not
None
:
self
.
_cached_num_tokens_total
+=
len
(
_block
.
token_ids
)
_block
=
_block
.
prev_block
return
self
.
_cached_num_tokens_total
return
self
.
_cached_num_tokens_total
@
property
@
property
...
@@ -635,7 +782,6 @@ class PrefixCachingBlock(Block):
...
@@ -635,7 +782,6 @@ class PrefixCachingBlock(Block):
For the content-based hash to be defined, the current block must be
For the content-based hash to be defined, the current block must be
full.
full.
"""
"""
# If the hash is already computed, return it.
# If the hash is already computed, return it.
if
self
.
_cached_content_hash
is
not
None
:
if
self
.
_cached_content_hash
is
not
None
:
return
self
.
_cached_content_hash
return
self
.
_cached_content_hash
...
@@ -685,7 +831,129 @@ class PrefixCachingBlock(Block):
...
@@ -685,7 +831,129 @@ class PrefixCachingBlock(Block):
return
hash
((
is_first_block
,
prev_block_hash
,
*
cur_block_token_ids
))
return
hash
((
is_first_block
,
prev_block_hash
,
*
cur_block_token_ids
))
class
ComputedBlocksTracker
:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
def
__init__
(
self
,
allocator
):
self
.
_allocator
=
allocator
self
.
_cached_computed_seq_blocks
:
Dict
[
int
,
Tuple
[
List
[
int
],
bool
]]
=
{}
def
add_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Start tracking seq_id
"""
assert
seq_id
not
in
self
.
_cached_computed_seq_blocks
self
.
_cached_computed_seq_blocks
[
seq_id
]
=
([],
False
)
def
remove_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Stop tracking seq_id
"""
assert
seq_id
in
self
.
_cached_computed_seq_blocks
del
self
.
_cached_computed_seq_blocks
[
seq_id
]
def
get_cached_computed_blocks_and_update
(
self
,
seq_id
:
int
,
block_ids
:
List
[
int
])
->
List
[
int
]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert
seq_id
in
self
.
_cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids
,
has_gap
=
self
.
_cached_computed_seq_blocks
[
seq_id
]
if
has_gap
:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return
prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks
=
len
(
block_ids
)
-
1
assert
num_cur_blocks
>=
0
if
len
(
prev_computed_block_ids
)
>=
num_cur_blocks
:
# Cache HIT
assert
len
(
prev_computed_block_ids
)
==
num_cur_blocks
return
prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids
=
self
.
_allocator
.
get_computed_block_ids
(
# noqa: E501
prev_computed_block_ids
,
block_ids
,
skip_last_block_id
=
True
,
# We skip last block id to avoid caching of full seq
)
# Detect if there is a "gap"
has_gap
=
len
(
computed_block_ids
)
<
num_cur_blocks
# Record
self
.
_cached_computed_seq_blocks
[
seq_id
]
=
(
computed_block_ids
,
has_gap
)
return
computed_block_ids
class
LastAccessBlocksTracker
:
"""Manages the last access time of the tracked sequences, in order to allow
an efficient update of allocator's block last access times
"""
def
__init__
(
self
,
allocator
):
self
.
_allocator
=
allocator
self
.
_seq_last_access
:
Dict
[
int
,
Optional
[
float
]]
=
{}
def
add_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Start tracking seq_id
"""
assert
seq_id
not
in
self
.
_seq_last_access
self
.
_seq_last_access
[
seq_id
]
=
None
def
remove_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Stop tracking seq_id
"""
assert
seq_id
in
self
.
_seq_last_access
del
self
.
_seq_last_access
[
seq_id
]
def
update_last_access
(
self
,
seq_id
:
int
,
time
:
float
)
->
None
:
assert
seq_id
in
self
.
_seq_last_access
self
.
_seq_last_access
[
seq_id
]
=
time
def
update_seq_blocks_last_access
(
self
,
seq_id
:
int
,
block_ids
:
List
[
int
])
->
None
:
assert
seq_id
in
self
.
_seq_last_access
ts
=
self
.
_seq_last_access
[
seq_id
]
if
ts
is
None
:
# No last access was recorded, no need to update.
return
self
.
_allocator
.
mark_blocks_as_accessed
(
block_ids
,
ts
)
def
assert_prefix_caching_block_or_none
(
block
:
Optional
[
Block
]):
def
assert_prefix_caching_block_or_none
(
block
:
Optional
[
Block
]):
if
block
is
None
:
if
block
is
None
:
return
return
assert
isinstance
(
block
,
PrefixCachingBlock
)
assert
isinstance
(
block
,
PrefixCachingBlock
),
"Got block = {}"
.
format
(
block
)
vllm/core/block_manager_v1.py
View file @
705f6a35
...
@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self
.
cross_block_tables
:
Dict
[
str
,
BlockTable
]
=
{}
self
.
cross_block_tables
:
Dict
[
str
,
BlockTable
]
=
{}
def
_get_seq_num_required_blocks
(
self
,
seq
:
Sequence
)
->
int
:
def
_get_seq_num_required_blocks
(
self
,
seq
:
Sequence
)
->
int
:
return
0
if
seq
is
None
\
return
0
if
seq
is
None
else
seq
.
n_blocks
else
len
(
seq
.
logical_token_blocks
)
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
# FIXME(woosuk): Here we assume that all sequences in the group share
# FIXME(woosuk): Here we assume that all sequences in the group share
...
@@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -298,7 +297,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
ref_count
:
int
,
\
ref_count
:
int
,
\
is_encoder_decoder
:
bool
=
True
)
->
BlockTable
:
is_encoder_decoder
:
bool
=
True
)
->
BlockTable
:
# Allocate new physical token blocks that will store the prompt tokens.
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks
=
len
(
seq
.
logical_toke
n_blocks
)
num_prompt_blocks
=
seq
.
n_blocks
block_table
:
BlockTable
=
[]
block_table
:
BlockTable
=
[]
for
logical_idx
in
range
(
num_prompt_blocks
):
for
logical_idx
in
range
(
num_prompt_blocks
):
...
@@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -367,7 +366,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Compute a new hash for the block so that it can be shared by other
# Compute a new hash for the block so that it can be shared by other
# Sequences
# Sequences
new_hash
=
seq
.
hash_of_block
(
len
(
seq
.
logical_toke
n_blocks
)
-
1
)
new_hash
=
seq
.
hash_of_block
(
seq
.
n_blocks
-
1
)
# if new_hash is already in the cached table, then free last_block
# if new_hash is already in the cached table, then free last_block
# and return the cached version
# and return the cached version
...
@@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -407,10 +406,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if
not
self
.
enable_caching
:
if
not
self
.
enable_caching
:
return
self
.
gpu_allocator
.
allocate
()
return
self
.
gpu_allocator
.
allocate
()
block_hash
:
Optional
[
int
]
=
None
block_hash
:
Optional
[
int
]
=
None
n_blocks
=
seq
.
n_blocks
if
(
self
.
_is_last_block_full
(
seq
)):
if
(
self
.
_is_last_block_full
(
seq
)):
block_hash
=
seq
.
hash_of_block
(
len
(
seq
.
logical_token_blocks
)
-
1
)
block_hash
=
seq
.
hash_of_block
(
n_blocks
-
1
)
num_hashed_tokens
=
seq
.
num_hashed_tokens_of_block
(
num_hashed_tokens
=
seq
.
num_hashed_tokens_of_block
(
n_blocks
-
1
)
len
(
seq
.
logical_token_blocks
)
-
1
)
# num_hashed_tokens is used to compute future hashes
# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
# (e.g. in the hashing function, it is used to ask the sequence for
...
@@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -429,12 +428,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_lookahead_slots
:
int
=
0
,
num_lookahead_slots
:
int
=
0
,
)
->
List
[
Tuple
[
int
,
int
]]:
)
->
List
[
Tuple
[
int
,
int
]]:
"""Allocate a physical slot for a new token."""
"""Allocate a physical slot for a new token."""
logical
_blocks
=
seq
.
logical_toke
n_blocks
n
_blocks
=
seq
.
n_blocks
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
# If we need to allocate a new physical block
# If we need to allocate a new physical block
if
len
(
block_table
)
<
len
(
logical
_blocks
)
:
if
len
(
block_table
)
<
n
_blocks
:
# Currently this code only supports adding one physical block
# Currently this code only supports adding one physical block
assert
len
(
block_table
)
==
len
(
logical
_blocks
)
-
1
assert
len
(
block_table
)
==
n
_blocks
-
1
if
(
self
.
block_sliding_window
if
(
self
.
block_sliding_window
and
len
(
block_table
)
>=
self
.
block_sliding_window
):
and
len
(
block_table
)
>=
self
.
block_sliding_window
):
...
@@ -472,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -472,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
# NOTE: fork does not allocate a new physical block.
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
# Thus, it is always safe from OOM.
if
parent_seq
.
seq_id
not
in
self
.
block_tables
:
# Parent sequence has either been freed or never existed.
return
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
copy
()
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
copy
()
# When using a sliding window, blocks will be eventually reused.
# When using a sliding window, blocks will be eventually reused.
...
...
vllm/core/block_manager_v2.py
View file @
705f6a35
...
@@ -7,6 +7,8 @@ from typing import Tuple
...
@@ -7,6 +7,8 @@ from typing import Tuple
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.interfaces
import
Block
from
vllm.core.block.interfaces
import
Block
from
vllm.core.block.prefix_caching_block
import
(
ComputedBlocksTracker
,
LastAccessBlocksTracker
)
from
vllm.core.block.utils
import
check_no_caching_or_swa_for_blockmgr_encdec
from
vllm.core.block.utils
import
check_no_caching_or_swa_for_blockmgr_encdec
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
...
@@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
block_tables
:
Dict
[
SeqId
,
BlockTable
]
=
{}
self
.
block_tables
:
Dict
[
SeqId
,
BlockTable
]
=
{}
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
self
.
_computed_blocks_tracker
=
ComputedBlocksTracker
(
self
.
block_allocator
)
self
.
_last_access_blocks_tracker
=
LastAccessBlocksTracker
(
self
.
block_allocator
)
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
# FIXME(woosuk): Here we assume that all sequences in the group share
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
# the same prompt. This may not be true for preempted sequences.
...
@@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table
:
BlockTable
=
self
.
_allocate_sequence
(
seq
)
block_table
:
BlockTable
=
self
.
_allocate_sequence
(
seq
)
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Assign the block table for each sequence.
# Assign the block table for each sequence.
for
seq
in
waiting_seqs
[
1
:]:
for
seq
in
waiting_seqs
[
1
:]:
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Allocate cross-attention block table for encoder sequence
# Allocate cross-attention block table for encoder sequence
#
#
# NOTE: Here we assume that all sequences in the group have the same
# NOTE: Here we assume that all sequences in the group have the same
...
@@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
new_cows
return
new_cows
def
free
(
self
,
seq
:
Sequence
)
->
None
:
def
free
(
self
,
seq
:
Sequence
)
->
None
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
seq_id
=
seq
.
seq_id
if
seq_id
not
in
self
.
block_tables
:
# Already freed or haven't been scheduled yet.
# Already freed or haven't been scheduled yet.
return
return
self
.
block_tables
[
seq
.
seq_id
].
free
()
del
self
.
block_tables
[
seq
.
seq_id
]
# Update seq block ids with the latest access time
self
.
_last_access_blocks_tracker
.
update_seq_blocks_last_access
(
seq_id
,
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
)
# Untrack seq
self
.
_last_access_blocks_tracker
.
remove_seq
(
seq_id
)
self
.
_computed_blocks_tracker
.
remove_seq
(
seq_id
)
# Free table/blocks
self
.
block_tables
[
seq_id
].
free
()
del
self
.
block_tables
[
seq_id
]
def
free_cross
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
free_cross
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
request_id
=
seq_group
.
request_id
request_id
=
seq_group
.
request_id
...
@@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
del
self
.
cross_block_tables
[
request_id
]
del
self
.
cross_block_tables
[
request_id
]
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
assert
seq
.
seq_id
in
self
.
block_tables
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
assert
all
(
b
is
not
None
for
b
in
block_ids
)
return
block_ids
# type: ignore
return
block_ids
# type: ignore
def
get_cross_block_table
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
int
]:
def
get_cross_block_table
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
int
]:
...
@@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
block_ids
# type: ignore
return
block_ids
# type: ignore
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
now
:
float
):
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
now
:
float
):
# Update the last accessed time of all the blocks accessed
# in this step.
# And the accessed time is only useful for prefix caching now,
# as it support internal evictor policy for which cached
# block could be refilled, to keep cached content could be reused
# at max extend.
if
self
.
enable_caching
:
if
self
.
enable_caching
:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
# Record the latest access time for the sequence. The actual update
block
_
ids
=
[]
# of the
block
ids
is deferred to the sequence free(..) call, since
for
block
_
id
in
block_table
.
physical_block_ids
:
# only during freeing of
block
id
s, the blocks are actually added to
block_ids
.
append
(
block_i
d
)
# the evictor (which is when the most updated time is require
d)
self
.
block_allocator
.
mark_blocks_as_accessed
(
# (This avoids expensive calls to
mark_blocks_as_accessed(
..))
block_ids
,
# type: ignore
self
.
_last_access_blocks_tracker
.
update_last_access
(
now
)
seq
.
seq_id
,
now
)
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
# The only need for mark block as computed is for prefix caching,
# The only need for mark block as computed is for prefix caching,
...
@@ -285,17 +304,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -285,17 +304,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
This method determines which blocks can be safely skipped for all
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
sequences in the sequence group.
"""
"""
seq_block_ids
=
[
computed_seq_block_ids
=
[]
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
for
seq
in
seqs
for
seq
in
seqs
:
]
computed_seq_block_ids
.
append
(
self
.
_computed_blocks_tracker
.
get_cached_computed_blocks_and_update
(
seq
.
seq_id
,
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
))
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return
self
.
block_allocator
.
get_common_computed_block_ids
(
return
self
.
block_allocator
.
get_common_computed_block_ids
(
seq_block_ids
)
# type: ignore
computed_
seq_block_ids
)
# type: ignore
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
if
parent_seq
.
seq_id
not
in
self
.
block_tables
:
# Parent sequence has either been freed or never existed.
return
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
# Track child seq
self
.
_computed_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
AllocStatus
:
num_lookahead_slots
:
int
)
->
AllocStatus
:
"""Returns the AllocStatus for the given sequence_group
"""Returns the AllocStatus for the given sequence_group
...
@@ -323,19 +354,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -323,19 +354,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from CPU
List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU.
to GPU.
"""
"""
blocks
=
self
.
_get_blocks_for_swap
(
seq_group
,
SequenceStatus
.
SWAPPED
)
physical_block_id_mapping
=
[]
current_swap_mapping
=
self
.
block_allocator
.
swap
(
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
blocks
=
blocks
,
source_device
=
Device
.
CPU
,
dest_device
=
Device
.
GPU
)
blocks
=
self
.
block_tables
[
seq
.
seq_id
].
blocks
if
len
(
blocks
)
==
0
:
block_number_mapping
=
{
continue
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
):
seq_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
blocks
,
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
src_device
=
Device
.
CPU
,
gpu_block_id
)
dst_device
=
Device
.
GPU
)
for
cpu_block_id
,
gpu_block_id
in
current_swap_mapping
.
items
()
}
# Refresh the block ids of the table (post-swap)
# convert to list of tuples once here
self
.
block_tables
[
seq
.
seq_id
].
update
(
blocks
)
return
list
(
block_number_mapping
.
items
())
seq_physical_block_id_mapping
=
{
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
):
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
)
for
cpu_block_id
,
gpu_block_id
in
seq_swap_mapping
.
items
()
}
physical_block_id_mapping
.
extend
(
list
(
seq_physical_block_id_mapping
.
items
()))
return
physical_block_id_mapping
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
"""Returns whether we can swap out the given sequence_group
"""Returns whether we can swap out the given sequence_group
...
@@ -355,7 +398,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -355,7 +398,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
True
return
True
return
False
return
False
def
swap_out
(
self
,
seq
uence
_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]]:
"""Returns the block id mapping (from GPU to CPU) generated by
"""Returns the block id mapping (from GPU to CPU) generated by
swapping out the given sequence_group with num_lookahead_slots.
swapping out the given sequence_group with num_lookahead_slots.
...
@@ -366,19 +409,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -366,19 +409,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from
List[Tuple[int, int]]: The mapping of swapping block from
GPU to CPU.
GPU to CPU.
"""
"""
blocks
=
self
.
_get_blocks_for_swap
(
sequence_group
,
physical_block_id_mapping
=
[]
SequenceStatus
.
RUNNING
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
current_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
self
.
block_tables
[
seq
.
seq_id
].
blocks
blocks
=
blocks
,
source_device
=
Device
.
GPU
,
dest_device
=
Device
.
CPU
)
if
len
(
blocks
)
==
0
:
block_number_mapping
=
{
continue
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
):
seq_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
blocks
,
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
src_device
=
Device
.
GPU
,
cpu_block_id
)
dst_device
=
Device
.
CPU
)
for
gpu_block_id
,
cpu_block_id
in
current_swap_mapping
.
items
()
}
# Refresh the block ids of the table (post-swap)
# convert to list of tuples once here
self
.
block_tables
[
seq
.
seq_id
].
update
(
blocks
)
return
list
(
block_number_mapping
.
items
())
seq_physical_block_id_mapping
=
{
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
):
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
)
for
gpu_block_id
,
cpu_block_id
in
seq_swap_mapping
.
items
()
}
physical_block_id_mapping
.
extend
(
list
(
seq_physical_block_id_mapping
.
items
()))
return
physical_block_id_mapping
def
get_num_free_gpu_blocks
(
self
)
->
int
:
def
get_num_free_gpu_blocks
(
self
)
->
int
:
return
self
.
block_allocator
.
get_num_free_blocks
(
Device
.
GPU
)
return
self
.
block_allocator
.
get_num_free_blocks
(
Device
.
GPU
)
...
...
vllm/core/scheduler.py
View file @
705f6a35
...
@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
...
@@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupMetadata
,
SequenceStatus
)
...
@@ -50,8 +51,8 @@ class SchedulingBudget:
...
@@ -50,8 +51,8 @@ class SchedulingBudget:
"""
"""
token_budget
:
int
token_budget
:
int
max_num_seqs
:
int
max_num_seqs
:
int
_reques
e
t_ids_num_batched_tokens
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_request_ids_num_batched_tokens
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_reques
e
t_ids_num_curr_seqs
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_request_ids_num_curr_seqs
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_num_batched_tokens
:
int
=
0
_num_batched_tokens
:
int
=
0
_num_curr_seqs
:
int
=
0
_num_curr_seqs
:
int
=
0
...
@@ -65,28 +66,28 @@ class SchedulingBudget:
...
@@ -65,28 +66,28 @@ class SchedulingBudget:
return
self
.
token_budget
-
self
.
num_batched_tokens
return
self
.
token_budget
-
self
.
num_batched_tokens
def
add_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
def
add_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_batched_tokens
:
if
req_id
in
self
.
_request_ids_num_batched_tokens
:
return
return
self
.
_reques
e
t_ids_num_batched_tokens
.
add
(
req_id
)
self
.
_request_ids_num_batched_tokens
.
add
(
req_id
)
self
.
_num_batched_tokens
+=
num_batched_tokens
self
.
_num_batched_tokens
+=
num_batched_tokens
def
subtract_num_batched_tokens
(
self
,
req_id
:
str
,
def
subtract_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
num_batched_tokens
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_batched_tokens
:
if
req_id
in
self
.
_request_ids_num_batched_tokens
:
self
.
_reques
e
t_ids_num_batched_tokens
.
remove
(
req_id
)
self
.
_request_ids_num_batched_tokens
.
remove
(
req_id
)
self
.
_num_batched_tokens
-=
num_batched_tokens
self
.
_num_batched_tokens
-=
num_batched_tokens
def
add_num_seqs
(
self
,
req_id
:
str
,
num_curr_seqs
:
int
):
def
add_num_seqs
(
self
,
req_id
:
str
,
num_curr_seqs
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_curr_seqs
:
if
req_id
in
self
.
_request_ids_num_curr_seqs
:
return
return
self
.
_reques
e
t_ids_num_curr_seqs
.
add
(
req_id
)
self
.
_request_ids_num_curr_seqs
.
add
(
req_id
)
self
.
_num_curr_seqs
+=
num_curr_seqs
self
.
_num_curr_seqs
+=
num_curr_seqs
def
subtract_num_seqs
(
self
,
req_id
:
str
,
num_curr_seqs
:
int
):
def
subtract_num_seqs
(
self
,
req_id
:
str
,
num_curr_seqs
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_curr_seqs
:
if
req_id
in
self
.
_request_ids_num_curr_seqs
:
self
.
_reques
e
t_ids_num_curr_seqs
.
remove
(
req_id
)
self
.
_request_ids_num_curr_seqs
.
remove
(
req_id
)
self
.
_num_curr_seqs
-=
num_curr_seqs
self
.
_num_curr_seqs
-=
num_curr_seqs
@
property
@
property
...
@@ -139,6 +140,8 @@ class SchedulerOutputs:
...
@@ -139,6 +140,8 @@ class SchedulerOutputs:
if
self
.
num_loras
>
0
:
if
self
.
num_loras
>
0
:
self
.
_sort_by_lora_ids
()
self
.
_sort_by_lora_ids
()
self
.
num_prompt_adapters
:
int
=
len
(
self
.
prompt_adapter_requests
)
def
is_empty
(
self
)
->
bool
:
def
is_empty
(
self
)
->
bool
:
# NOTE: We do not consider the ignored sequence groups.
# NOTE: We do not consider the ignored sequence groups.
return
(
not
self
.
scheduled_seq_groups
and
not
self
.
blocks_to_swap_in
return
(
not
self
.
scheduled_seq_groups
and
not
self
.
blocks_to_swap_in
...
@@ -157,6 +160,14 @@ class SchedulerOutputs:
...
@@ -157,6 +160,14 @@ class SchedulerOutputs:
if
g
.
seq_group
.
lora_request
is
not
None
if
g
.
seq_group
.
lora_request
is
not
None
}
}
@
property
def
prompt_adapter_requests
(
self
)
->
Set
[
PromptAdapterRequest
]:
return
{
g
.
seq_group
.
prompt_adapter_request
for
g
in
self
.
scheduled_seq_groups
if
g
.
seq_group
.
prompt_adapter_request
is
not
None
}
@
dataclass
@
dataclass
class
SchedulerRunningOutputs
:
class
SchedulerRunningOutputs
:
...
@@ -256,6 +267,7 @@ class Scheduler:
...
@@ -256,6 +267,7 @@ class Scheduler:
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
pipeline_parallel_size
:
int
=
1
,
)
->
None
:
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -273,11 +285,19 @@ class Scheduler:
...
@@ -273,11 +285,19 @@ class Scheduler:
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
version
)
version
)
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
if
num_gpu_blocks
:
num_gpu_blocks
//=
pipeline_parallel_size
num_cpu_blocks
=
cache_config
.
num_cpu_blocks
if
num_cpu_blocks
:
num_cpu_blocks
//=
pipeline_parallel_size
# Create the block space manager.
# Create the block space manager.
self
.
block_manager
=
BlockSpaceManagerImpl
(
self
.
block_manager
=
BlockSpaceManagerImpl
(
block_size
=
self
.
cache_config
.
block_size
,
block_size
=
self
.
cache_config
.
block_size
,
num_gpu_blocks
=
self
.
cache_config
.
num_gpu_blocks
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
self
.
cache_config
.
num_cpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
,
sliding_window
=
self
.
cache_config
.
sliding_window
,
sliding_window
=
self
.
cache_config
.
sliding_window
,
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
)
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
)
...
@@ -290,7 +310,10 @@ class Scheduler:
...
@@ -290,7 +310,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
# Contain decode requests that are swapped out.
self
.
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
self
.
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
self
.
_finished_requests_ids
:
List
[
str
]
=
list
()
# Time at previous scheduling step
# Time at previous scheduling step
self
.
prev_time
=
0.0
self
.
prev_time
=
0.0
# Did we schedule a prompt at previous step?
# Did we schedule a prompt at previous step?
...
@@ -364,6 +387,12 @@ class Scheduler:
...
@@ -364,6 +387,12 @@ class Scheduler:
def
get_num_unfinished_seq_groups
(
self
)
->
int
:
def
get_num_unfinished_seq_groups
(
self
)
->
int
:
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
+
len
(
self
.
swapped
)
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
+
len
(
self
.
swapped
)
def
get_and_reset_finished_requests_ids
(
self
)
->
List
[
str
]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids
=
self
.
_finished_requests_ids
self
.
_finished_requests_ids
=
list
()
return
finished_requests_ids
def
_schedule_running
(
def
_schedule_running
(
self
,
self
,
running_queue
:
deque
,
running_queue
:
deque
,
...
@@ -1006,6 +1035,7 @@ class Scheduler:
...
@@ -1006,6 +1035,7 @@ class Scheduler:
# `multi_modal_data` will be None.
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
)
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
...
@@ -1027,6 +1057,11 @@ class Scheduler:
...
@@ -1027,6 +1057,11 @@ class Scheduler:
self
.
block_manager
.
free
(
seq
)
self
.
block_manager
.
free
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
def
free_finished_seq_groups
(
self
)
->
None
:
for
queue
in
[
self
.
running
,
self
.
swapped
,
self
.
waiting
]:
self
.
_finished_requests_ids
+=
[
seq_group
.
request_id
for
seq_group
in
queue
if
seq_group
.
is_finished
()
]
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
if
not
seq_group
.
is_finished
())
if
not
seq_group
.
is_finished
())
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment