Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1080 additions
and
359 deletions
+1080
-359
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+62
-12
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+9
-4
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+23
-22
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+220
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+55
-58
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+18
-18
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+33
-33
vllm/attention/layer.py
vllm/attention/layer.py
+7
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+19
-18
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+56
-19
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+11
-13
vllm/attention/selector.py
vllm/attention/selector.py
+13
-9
vllm/config.py
vllm/config.py
+149
-58
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+10
-6
vllm/core/block/common.py
vllm/core/block/common.py
+16
-4
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+35
-13
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+100
-5
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+53
-10
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+189
-55
No files found.
vllm/__init__.py
View file @
1591c68f
...
...
@@ -3,14 +3,14 @@
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.ray_utils
import
initialize_ray_cluster
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__dcu_version__
__version__
=
"0.4.
1
"
__version__
=
"0.4.
2
"
__all__
=
[
"LLM"
,
...
...
vllm/_custom_ops.py
View file @
1591c68f
...
...
@@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
vllm_ops
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_
context_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_
seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
def
paged_attention_v2
(
...
...
@@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
vllm_ops
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
block_size
,
max_
context
_len
,
alibi_slopes
,
kv_cache_dtype
,
block_tables
,
seq
_lens
,
block_size
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
...
...
@@ -153,11 +153,49 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
# aqlm
def aqlm_gemm(
    input: torch.Tensor,
    codes: torch.Tensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    codebook_partition_sizes: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Invoke the ``aqlm_gemm`` custom op.

    Thin Python wrapper: all arguments are forwarded positionally, in the
    order they are declared here, to ``vllm_ops.aqlm_gemm`` and the op's
    result is returned unchanged.
    """
    kernel_args = (input, codes, codebooks, scales,
                   codebook_partition_sizes, bias)
    return vllm_ops.aqlm_gemm(*kernel_args)
def aqlm_dequant(
    codes: torch.Tensor,
    codebooks: torch.Tensor,
    codebook_partition_sizes: torch.Tensor,
) -> torch.Tensor:
    """Invoke the ``aqlm_dequant`` custom op.

    Thin Python wrapper: the three tensors are forwarded positionally to
    ``vllm_ops.aqlm_dequant`` and the op's result is returned unchanged.
    """
    kernel_args = (codes, codebooks, codebook_partition_sizes)
    return vllm_ops.aqlm_dequant(*kernel_args)
# gptq_marlin
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Invoke the ``gptq_marlin_repack`` custom op.

    Thin Python wrapper: arguments are forwarded positionally, in the
    declared order, to ``vllm_ops.gptq_marlin_repack`` and the op's result
    is returned unchanged.
    """
    kernel_args = (b_q_weight, perm, size_k, size_n, num_bits)
    return vllm_ops.gptq_marlin_repack(*kernel_args)
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
) -> torch.Tensor:
    """Invoke the ``gptq_marlin_gemm`` custom op.

    Thin Python wrapper: arguments are forwarded positionally, in the
    declared order, to ``vllm_ops.gptq_marlin_gemm`` and the op's result is
    returned unchanged.
    """
    kernel_args = (a, b_q_weight, b_scales, g_idx, perm, workspace,
                   num_bits, size_m, size_n, size_k, is_k_full)
    return vllm_ops.gptq_marlin_gemm(*kernel_args)
# fp8
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
vllm_ops
.
scaled_fp8_quant
(
output
,
input
,
scale
)
if
scale
is
None
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
vllm_ops
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
vllm_ops
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
...
...
@@ -184,6 +222,18 @@ def reshape_and_cache(
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Invoke the ``reshape_and_cache_flash`` cache op.

    Thin Python wrapper: arguments are forwarded positionally, in the
    declared order, to ``vllm_cache_ops.reshape_and_cache_flash``. Nothing
    is returned; any effect happens inside the op.
    """
    kernel_args = (key, value, key_cache, value_cache, slot_mapping,
                   kv_cache_dtype)
    vllm_cache_ops.reshape_and_cache_flash(*kernel_args)
def copy_blocks(
    key_caches: torch.Tensor,
    value_caches: torch.Tensor,
    block_mapping: torch.Tensor,
) -> None:
    """Invoke the ``copy_blocks`` cache op.

    Thin Python wrapper: the three arguments are forwarded positionally to
    ``vllm_cache_ops.copy_blocks``. Nothing is returned; any effect happens
    inside the op.
    """
    kernel_args = (key_caches, value_caches, block_mapping)
    vllm_cache_ops.copy_blocks(*kernel_args)
...
...
vllm/attention/backends/abstract.py
View file @
1591c68f
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
import
torch
...
...
@@ -15,7 +16,7 @@ class AttentionBackend(ABC):
@
staticmethod
@
abstractmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata
PerStage
"
:
raise
NotImplementedError
@
staticmethod
...
...
@@ -50,13 +51,17 @@ class AttentionBackend(ABC):
class
AttentionMetadataPerStage
:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if
skip_fields
is
None
:
skip_fields
=
set
()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return
{
field
.
name
:
getattr
(
self
,
field
.
name
)
for
field
in
fields
(
self
)
for
field
in
fields
(
self
)
if
field
.
name
not
in
skip_fields
}
...
...
vllm/attention/backends/flash_attn.py
View file @
1591c68f
...
...
@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len,
sub
query_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq
_
len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-
sub
query_len -|
# |-------------------- seq
_
len ----------------------|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum prompt length in the batch.
max_prompt_len
:
Optional
[
int
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
...
...
@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
...
...
@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_k
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_q
=
prefill_meta
.
max_
seq
_len
,
max_seqlen_k
=
prefill_meta
.
max_
seq
_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
...
...
@@ -245,10 +245,11 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
...
...
@@ -257,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
max_
context
_len
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
vllm/attention/backends/flashinfer.py
0 → 100644
View file @
1591c68f
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
import
flashinfer
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
except
ImportError
:
flashinfer
=
None
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadataPerStage
)
class FlashInferBackend(AttentionBackend):
    """Static description of the FlashInfer attention backend.

    Exposes the implementation and metadata classes and the KV-cache
    layout. Block swapping and copying are not implemented for this
    backend.
    """

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashInferImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
        """Build a ``FlashInferMetadata`` from the given arguments."""
        return FlashInferMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape used by this backend.

        The second dimension has size 2; ``FlashInferImpl.forward`` indexes
        it as ``kv_cache[:, 0]`` / ``kv_cache[:, 1]`` for the key and value
        caches.
        """
        kv_pair = 2
        return (num_blocks, kv_pair, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Not implemented for this backend; always raises."""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Not implemented for this backend; always raises."""
        raise NotImplementedError

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by ``FlashInferMetadata``."""
        return [64, 128, 256]
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
    """Per-stage attention metadata for the FlashInfer backend.

    For a decode stage (``is_prompt`` is False), ``__post_init__`` creates
    a ``BatchDecodeWithPagedKVCacheWrapper`` and calls ``begin_forward``
    on it with the paged-KV fields below.

    Raises:
        ValueError: if ``head_dim`` is set and is not one of
            ``FlashInferBackend.get_supported_head_sizes()``.
        ImportError: if a decode-stage instance is created while the
            ``flashinfer`` package could not be imported.
    """

    # True when this metadata describes a prefill (prompt) stage.
    is_prompt: bool

    use_cuda_graph: bool = False

    # Built in __post_init__ for decode stages; stays None for prefill.
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage since we still
    # use flash attention for prefill.
    seq_start_loc: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None
    block_tables: Optional[torch.Tensor] = None

    # Metadata for the decode stage
    # Workspace buffer required by the kernel, the buffer should not
    # be allocated/deallocated by the FlashInferMetadata object.
    workspace_buffer: Optional[torch.Tensor] = None
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None

    def __post_init__(self):
        """Validate ``head_dim`` and, for decode, set up the wrapper."""
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if (self.head_dim is not None
                and self.head_dim not in supported_head_sizes):
            # The two f-strings are adjacent literals (no comma between
            # them) so ValueError receives a single message string.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

        # A FlashInferMetadata is also created for the prefill stage, and
        # __post_init__ runs for it too; the decode wrapper is only set up
        # when this is not the prefill phase.
        if not self.is_prompt:
            if flashinfer is None:
                # The module-level guarded import fell back to None.
                raise ImportError(
                    "flashinfer is required for the FlashInfer decode "
                    "path but could not be imported.")
            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD")
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the fields as a dict, always omitting ``decode_wrapper``.

        Args:
            skip_fields: extra field names to leave out. The set passed in
                is not modified.

        Returns:
            Whatever the parent ``asdict_zerocopy`` returns for the
            combined skip set.
        """
        # Copy so that adding 'decode_wrapper' never mutates the caller's set.
        fields_to_skip = set() if skip_fields is None else set(skip_fields)
        # We need to skip the decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        fields_to_skip.add('decode_wrapper')
        return super().asdict_zerocopy(fields_to_skip)
class FlashInferImpl(AttentionImpl):
    """Attention implementation for the FlashInfer backend.

    Prefill calls ``flash_attn_varlen_func``; decode calls the FlashInfer
    decode wrapper stored on the decode metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the attention hyper-parameters.

        Raises:
            ValueError: if ``sliding_window`` is not None.
        """
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.alibi_slopes = alibi_slopes
        # Passed as ``window_size`` to flash_attn_varlen_func in forward();
        # presumably (-1, -1) means "no window" there — TODO confirm.
        self.sliding_window = (-1, -1)
        # Fall back to one KV head per query head when not specified.
        if num_kv_heads is None:
            self.num_kv_heads = num_heads
        else:
            self.num_kv_heads = num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata[FlashInferMetadata],
        kv_scale: float,
    ):
        """Run attention for a prefill-only or decode-only batch.

        ``query`` must be 2-D; its two dimensions are used to reshape the
        returned tensor. ``kv_scale`` is accepted but not used here.

        Raises:
            NotImplementedError: on prefill when ``kv_cache`` is given and
                the prefill block tables are non-empty (prefix caching).
        """
        num_tokens, hidden_size = query.shape
        q_shape = (-1, self.num_heads, self.head_size)
        kv_shape = (-1, self.num_kv_heads, self.head_size)
        query = query.view(*q_shape)
        key = key.view(*kv_shape)
        value = value.view(*kv_shape)

        # A batch may contain prefill tokens or decode tokens, not both.
        chunked_prefill_msg = (
            "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, chunked_prefill_msg
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, chunked_prefill_msg

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            key_cache = kv_cache[:, 0]
            value_cache = kv_cache[:, 1]
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                attn_metadata.kv_cache_dtype,
            )

        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            assert prefill_meta.block_tables is not None
            if (kv_cache is not None
                    and prefill_meta.block_tables.numel() != 0):
                raise NotImplementedError(
                    "Prefix caching is not supported with flashinfer yet.")
            # The same cumulative offsets and max length are used for both
            # the query and the key side.
            output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_q=prefill_meta.max_seq_len,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
                alibi_slopes=self.alibi_slopes,
            )
        else:
            decode_meta = attn_metadata.decode_metadata
            assert decode_meta is not None
            assert decode_meta.decode_wrapper is not None
            # Flashinfer requires query to be contiguous
            query = query.contiguous()
            output = decode_meta.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
            )

        return output.view(num_tokens, hidden_size)
vllm/attention/backends/rocm_flash_attn.py
View file @
1591c68f
"""Attention layer ROCm GPUs."""
import
os
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadataPerStage
)
...
...
@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len,
sub
query_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq
_
len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-
sub
query_len -|
# |-------------------- seq
_
len ----------------------|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum prompt length in the batch.
max_prompt_len
:
Optional
[
int
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
...
...
@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
...
...
@@ -156,8 +156,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
))
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
...
...
@@ -248,41 +247,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
assert
prefill_meta
.
prompt
_lens
is
not
None
assert
prefill_meta
.
seq
_lens
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
or
self
.
use_naive_attn
:
if
self
.
use_triton_flash_attn
:
out
,
_
=
self
.
attn_func
(
query
,
key
,
value
,
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_seq_len
,
prefill_meta
.
max_seq_len
,
True
,
self
.
scale
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
use_naive_attn
:
out
=
self
.
attn_func
(
query
,
key
,
value
,
prefill_meta
.
prompt_lens
,
self
.
scale
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
out
,
_
=
self
.
attn_func
(
query
,
key
,
value
,
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prompt_len
,
prefill_meta
.
max_prompt_len
,
True
,
self
.
scale
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
out
=
self
.
attn_func
(
query
,
key
,
value
,
prefill_meta
.
seq_lens
,
self
.
scale
,
)
else
:
out
=
self
.
attn_func
(
q
=
query
,
...
...
@@ -290,13 +284,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_k
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_q
=
prefill_meta
.
max_
seq
_len
,
max_seqlen_k
=
prefill_meta
.
max_
seq
_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
...
...
@@ -307,10 +303,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
@@ -320,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
max_
context
_len
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
@@ -337,13 +334,13 @@ def _naive_attention(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
scale
:
float
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
start
=
0
for
_
,
prompt
_len
in
enumerate
(
prompt
_lens
):
end
=
start
+
prompt
_len
for
_
,
seq
_len
in
enumerate
(
seq
_lens
):
end
=
start
+
seq
_len
out
=
_naive_masked_attention
(
query
[
start
:
end
],
key
[
start
:
end
],
...
...
@@ -352,7 +349,7 @@ def _naive_attention(
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt
_len
start
+=
seq
_len
return
output
...
...
vllm/attention/backends/torch_sdpa.py
View file @
1591c68f
...
...
@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
prompt
_lens
:
Optional
[
List
[
int
]]
seq
_lens
:
Optional
[
List
[
int
]]
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
...
...
@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale
)
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
prompt
_lens
is
not
None
assert
attn_metadata
.
seq
_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
...
...
@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
prompt
_lens
)
# type: ignore
attn_metadata
.
seq
_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
prompt
_lens
,
self
.
sliding_window
,
attn_metadata
.
seq
_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
att_masks
=
[
None
]
*
len
(
attn_metadata
.
prompt
_lens
)
att_masks
=
[
None
]
*
len
(
attn_metadata
.
seq
_lens
)
attn_metadata
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
...
...
@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
for
prompt
_len
,
mask
in
zip
(
attn_metadata
.
prompt
_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
prompt
_len
for
seq
_len
,
mask
in
zip
(
attn_metadata
.
seq
_lens
,
attn_metadata
.
attn_bias
):
end
=
start
+
seq
_len
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
...
...
@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache
,
value_cache
,
attn_metadata
.
block_tables
,
attn_metadata
.
context_lens
,
attn_metadata
.
max_
context
_len
,
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
bias
=
torch
.
arange
(
prompt
_len
,
dtype
=
dtype
)
for
seq
_len
in
seq
_lens
:
bias
=
torch
.
arange
(
seq
_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(
prompt
_len, 1)`
# `bias = bias[None, :].repeat(
seq
_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
...
...
@@ -221,7 +221,7 @@ def _make_alibi_bias(
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
prompt_len
,
prompt
_len
),
(
1
,
seq_len
,
seq
_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
...
...
@@ -229,14 +229,14 @@ def _make_alibi_bias(
def
_make_sliding_window_bias
(
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
tensor
=
torch
.
full
(
(
1
,
prompt_len
,
prompt
_len
),
(
1
,
seq_len
,
seq
_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
...
...
vllm/attention/backends/xformers.py
View file @
1591c68f
...
...
@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-
sub
query_len -|
# |-------------------- seq
_
len ----------------------|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# FIXME: It is for flash attn.
# Maximum
prompt
length in the batch.
max_
prompt
_len
:
Optional
[
int
]
# Maximum
sequence
length in the batch.
max_
seq
_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
...
...
@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
...
...
@@ -242,10 +241,11 @@ class XFormersImpl(AttentionImpl):
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
...
...
@@ -256,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
max_
context
_len
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
...
...
@@ -288,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert
attn_metadata
.
prompt
_lens
is
not
None
assert
attn_metadata
.
seq
_lens
is
not
None
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
...
...
@@ -309,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
prompt
_lens
)
attn_metadata
.
seq
_lens
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
...
...
@@ -317,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else
:
attn_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
prompt
_lens
)
attn_metadata
.
seq
_lens
)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
...
...
@@ -342,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
original_query
)
start
=
0
for
i
,
prompt
_len
in
enumerate
(
attn_metadata
.
prompt
_lens
):
end
=
start
+
prompt
_len
for
i
,
seq
_len
in
enumerate
(
attn_metadata
.
seq
_lens
):
end
=
start
+
seq
_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
...
...
@@ -353,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]))
start
+=
prompt
_len
start
+=
seq
_len
return
output
...
...
@@ -361,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
)
->
LowerTriangularMaskWithTensorBias
:
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
bias
=
torch
.
arange
(
prompt
_len
,
dtype
=
dtype
)
for
seq
_len
in
seq
_lens
:
bias
=
torch
.
arange
(
seq
_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(
prompt
_len, 1)`
# `bias = bias[None, :].repeat(
seq
_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
...
...
@@ -375,16 +375,16 @@ def _make_alibi_bias(
# element.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
prompt
_len
+
7
)
//
8
*
8
padded_len
=
(
seq
_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
1
,
# batch size
num_heads
,
prompt
_len
,
seq
_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
prompt
_len
].
copy_
(
bias
)
)[:,
:,
:,
:
seq
_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
...
...
vllm/attention/layer.py
View file @
1591c68f
...
...
@@ -47,3 +47,10 @@ class Attention(nn.Module):
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
kv_scale
)
def extra_repr(self) -> str:
    """Return a one-line summary of the attention implementation settings.

    The summary lists ``head_size``, ``num_heads``, ``num_kv_heads`` and
    ``scale`` as read from ``self.impl``, rendered as comma-separated
    ``name=value`` pairs.
    """
    impl = self.impl
    parts = (
        f"head_size={impl.head_size}",  # type: ignore
        f"num_heads={impl.num_heads}",  # type: ignore
        f"num_kv_heads={impl.num_kv_heads}",  # type: ignore
        f"scale={impl.scale}",  # type: ignore
    )
    return ", ".join(parts)
vllm/attention/ops/paged_attn.py
View file @
1591c68f
...
...
@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@
dataclass
class
PagedAttentionMetadata
:
"""Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens
:
Optional
[
torch
.
Tensor
]
# Maximum context length in the batch.
max_context_len
:
Optional
[
int
]
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
...
...
@@ -85,8 +84,8 @@ class PagedAttention:
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
max_
context
_len
:
int
,
seq
_lens
:
torch
.
Tensor
,
max_
seq
_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
...
...
@@ -97,7 +96,7 @@ class PagedAttention:
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
((
max_
context
_len
+
_PARTITION_SIZE
-
1
)
//
max_num_partitions
=
((
max_
seq
_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
...
...
@@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
(
max_
context
_len
<=
8192
use_v1
=
(
max_
seq
_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
# Run PagedAttention V1.
...
...
@@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
...
...
@@ -168,10 +167,11 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
subquery_start_loc
:
torch
.
Tensor
,
prompt
_lens_tensor
:
torch
.
Tensor
,
seq
_lens_tensor
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_
sub
query_len
:
int
,
max_query_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
Optional
[
int
],
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
context_attention_fwd
(
...
...
@@ -184,10 +184,11 @@ class PagedAttention:
block_tables
,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc
[:
-
1
],
prompt
_lens_tensor
,
seq
_lens_tensor
,
context_lens
,
max_
sub
query_len
,
max_query_len
,
alibi_slopes
,
sliding_window
,
)
return
output
...
...
vllm/attention/ops/prefix_prefill.py
View file @
1591c68f
...
...
@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
SLIDING_WINDOW
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
...
...
@@ -62,42 +63,53 @@ if triton.__version__ >= "2.1.0":
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_query_len
=
cur_batch_seq_len
-
cur_batch_ctx_len
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc
=
BLOCK_M
*
start_m
# initialize offsets
# [N]; starts at 0
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
# [D]; starts at 0
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
# [M]; starts at current position in query
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
# [M,D]
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
dim_mask
=
tl
.
where
(
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
# [D]
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_query_len
),
other
=
0.0
)
other
=
0.0
)
# [M,D]
# # initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL_PADDED
],
dtype
=
tl
.
float32
)
# initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
# [M]
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
# [M]
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL_PADDED
],
dtype
=
tl
.
float32
)
# [M,D]
# compute query against context (no causal mask here)
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
bn
=
tl
.
load
(
B_Loc
+
cur_batch
*
stride_b_loc_b
+
((
start_n
+
offs_n
)
//
block_size
)
*
stride_b_loc_s
,
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
other
=
0
)
other
=
0
)
# [N]
# [D,N]
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
cur_kv_head
*
stride_k_cache_h
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
stride_k_cache_bl
+
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
# [N,D]
off_v
=
(
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_kv_head
*
stride_v_cache_h
+
...
...
@@ -106,23 +118,39 @@ if triton.__version__ >= "2.1.0":
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# [D,N]
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# [M,N]
qk
+=
tl
.
dot
(
q
,
k
)
qk
=
tl
.
where
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
qk
,
float
(
"-inf"
))
qk
*=
sm_scale
if
SLIDING_WINDOW
>
0
:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk
=
tl
.
where
((
cur_batch_ctx_len
+
offs_m
[:,
None
])
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
-
10000
)
# -- compute m_ij, p, l_ij
m_ij
=
tl
.
max
(
qk
,
1
)
p
=
tl
.
exp
(
qk
-
m_ij
[:,
None
])
l_ij
=
tl
.
sum
(
p
,
1
)
m_ij
=
tl
.
max
(
qk
,
1
)
# [M]
p
=
tl
.
exp
(
qk
-
m_ij
[:,
None
])
# [M,N]
l_ij
=
tl
.
sum
(
p
,
1
)
# [M]
# -- update m_i and l_i
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
# [M]
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
# [M]
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
# [M]
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
# [M]
# -- update output accumulator --
# scale p
p_scale
=
beta
/
l_i_new
...
...
@@ -134,7 +162,7 @@ if triton.__version__ >= "2.1.0":
v
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# [N,D]
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
)
...
...
@@ -149,8 +177,10 @@ if triton.__version__ >= "2.1.0":
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
# block_mask is 0 when we're already past the current query length
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_query_len
,
1
,
0
)
# compute query against itself (with causal mask)
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
...
...
@@ -163,8 +193,13 @@ if triton.__version__ >= "2.1.0":
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
*=
sm_scale
# apply causal mask
qk
=
tl
.
where
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]),
qk
,
float
(
"-inf"
))
if
SLIDING_WINDOW
>
0
:
qk
=
tl
.
where
(
offs_m
[:,
None
]
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
-
10000
)
# -- compute m_ij, p, l_ij
m_ij
=
tl
.
max
(
qk
,
1
)
...
...
@@ -636,7 +671,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len
,
b_ctx_len
,
max_input_len
,
alibi_slopes
=
None
):
alibi_slopes
=
None
,
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
...
...
@@ -644,7 +680,7 @@ if triton.__version__ >= "2.1.0":
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded
=
2
**
((
Lk
-
1
).
bit_length
()
)
Lk_padded
=
triton
.
next_power_of_2
(
Lk
)
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
...
...
@@ -749,6 +785,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
if
sliding_window
is
not
None
else
0
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
...
...
vllm/attention/ops/triton_flash_attention.py
View file @
1591c68f
...
...
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
num_warps
=
4
,
),
],
key
=
[
"hq"
,
"hk"
,
"
IS_CAUSAL
"
,
"
dropout_p
"
,
"
BLOCK_DMODEL
"
],
key
=
[
'
IS_CAUSAL
'
,
'
dropout_p
'
,
'
BLOCK_DMODEL
'
],
)
@
triton
.
jit
def
attn_fwd
(
...
...
@@ -330,8 +330,8 @@ def attn_fwd(
philox_seed
,
philox_offset_base
,
encoded_softmax
,
hq
,
hk
,
HQ
:
tl
.
constexpr
,
HK
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
...
...
@@ -403,7 +403,7 @@ def attn_fwd(
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z *
hq
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# l_ptrs = L + off_z *
HQ
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
...
...
@@ -414,11 +414,9 @@ def attn_fwd(
# TODO: Should dropout and return encoded softmax be handled here?
return
is_mqa
=
hq
!=
hk
if
is_mqa
:
# noqa: SIM108
off_h_k
=
off_h_q
%
hk
else
:
off_h_k
=
off_h_q
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE
:
tl
.
constexpr
=
HQ
//
HK
off_h_k
=
off_h_q
//
GROUP_SIZE
if
GROUP_SIZE
!=
1
else
off_h_q
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
...
...
@@ -471,7 +469,7 @@ def attn_fwd(
bias_ptr
=
None
if
ENABLE_DROPOUT
:
batch_philox_offset
=
philox_offset_base
\
+
(
off_z
*
hq
+
off_h_q
)
\
+
(
off_z
*
HQ
+
off_h_q
)
\
*
seqlen_q
*
seqlen_k
else
:
batch_philox_offset
=
0
...
...
@@ -624,7 +622,7 @@ def attn_fwd(
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
# l_ptrs = L + off_z *
hq
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# l_ptrs = L + off_z *
HQ
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
...
...
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
philox_seed
=
philox_seed
,
philox_offset_base
=
philox_offset
,
encoded_softmax
=
encoded_softmax
,
hq
=
nheads_q
,
hk
=
nheads_k
,
HQ
=
nheads_q
,
HK
=
nheads_k
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
MAX_SEQLENS_Q
=
max_seqlens_q
,
MAX_SEQLENS_K
=
max_seqlens_k
,
...
...
vllm/attention/selector.py
View file @
1591c68f
import
enum
import
os
from
functools
import
lru_cache
from
typing
import
Type
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_cpu
,
is_hip
logger
=
init_logger
(
__name__
)
VLLM_ATTENTION_BACKEND
=
"VLLM_ATTENTION_BACKEND"
class _Backend(enum.Enum):
    """Identifiers for the attention backends that can be selected.

    ``get_attn_backend`` compares the chosen member against these values to
    decide which backend class to import and return. Members are also looked
    up by name (``_Backend[...]``) from the ``VLLM_ATTENTION_BACKEND``
    setting, so the member names double as the accepted values of that
    setting.
    """
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
dtype
:
torch
.
dtype
)
->
Type
[
AttentionBackend
]:
backend
=
_which_attn_to_use
(
dtype
)
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using FlashAttention backend."
)
logger
.
info
(
"Using FlashAttention
-2
backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
return
FlashAttentionBackend
...
...
@@ -43,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend. "
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
else
:
raise
ValueError
(
"Invalid attention backend."
)
...
...
@@ -62,12 +66,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
# NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention backend for Volta and Turing "
logger
.
info
(
"Cannot use FlashAttention
-2
backend for Volta and Turing "
"GPUs."
)
return
_Backend
.
XFORMERS
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info
(
"Cannot use FlashAttention backend for dtype other than "
logger
.
info
(
"Cannot use FlashAttention
-2
backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
return
_Backend
.
XFORMERS
...
...
@@ -75,11 +79,11 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
import
flash_attn
# noqa: F401
except
ImportError
:
logger
.
info
(
"Cannot use FlashAttention backend because the flash_attn
package
"
"is not found. Please install it for better performance."
)
"Cannot use FlashAttention
-2
backend because the flash_attn "
"
package
is not found. Please install it for better performance."
)
return
_Backend
.
XFORMERS
backend_by_env_var
=
os
.
get
env
(
VLLM_ATTENTION_BACKEND
)
backend_by_env_var
=
env
s
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
return
_Backend
[
backend_by_env_var
]
...
...
vllm/config.py
View file @
1591c68f
import
enum
import
json
import
os
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Union
...
...
@@ -9,11 +8,14 @@ from packaging.version import Version
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
(
QUANTIZATION_METHODS
,
get_quantization_config
)
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
is_neuron
)
GPTQMarlinConfig
=
get_quantization_config
(
"gptq_marlin"
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -21,10 +23,6 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE
=
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
_GB
=
1
<<
30
...
...
@@ -33,6 +31,8 @@ class ModelConfig:
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
...
...
@@ -65,9 +65,16 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
"""
def
__init__
(
...
...
@@ -86,8 +93,10 @@ class ModelConfig:
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
5
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
)
->
None
:
self
.
model
=
model
self
.
tokenizer
=
tokenizer
...
...
@@ -101,6 +110,11 @@ class ModelConfig:
self
.
quantization_param_path
=
quantization_param_path
self
.
enforce_eager
=
enforce_eager
self
.
max_context_len_to_capture
=
max_context_len_to_capture
if
self
.
max_context_len_to_capture
is
not
None
:
raise
ValueError
(
"`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead."
)
self
.
max_seq_len_to_capture
=
(
max_seq_len_to_capture
or
max_context_len_to_capture
)
self
.
max_logprobs
=
max_logprobs
self
.
skip_tokenizer_init
=
skip_tokenizer_init
...
...
@@ -110,6 +124,8 @@ class ModelConfig:
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
max_model_len
)
self
.
served_model_name
=
get_served_model_name
(
model
,
served_model_name
)
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_quantization
()
...
...
@@ -138,14 +154,34 @@ class ModelConfig:
is_format_marlin
=
(
quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin"
or
quant_cfg
.
get
(
"is_marlin_format"
,
False
))
# Use marlin if the GPTQ model is serialized in marlin format.
if
quant_method
==
"gptq"
and
is_format_marlin
:
logger
.
info
(
"The model is serialized in Marlin format. "
# Check which LinearMethod the GPTQ model should use.
if
quant_method
==
"gptq"
:
# If serialized in Marlin format, use MarlinLinearMethod.
# TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
if
is_format_marlin
:
logger
.
info
(
"The model is serialized in Marlin format. "
"Using Marlin kernel."
)
quant_method
=
"marlin"
if
self
.
quantization
==
"gptq"
:
self
.
quantization
=
quant_method
# If convertible to Marlin format, use GPTQMarlinLinearMethod
# unless the user explicitly specified GPTQLinearMethod.
elif
GPTQMarlinConfig
.
is_marlin_compatible
(
quant_cfg
):
if
self
.
quantization
==
"gptq"
:
logger
.
warning
(
"The model is convertible to Marlin format, but "
"you specified quantization=gptq. Use "
"quantization=marlin for faster inference."
)
else
:
logger
.
info
(
"The model is convertible to Marlin format. "
"Using Marlin kernel."
)
quant_method
=
"marlin"
if
self
.
quantization
==
"
gptq
"
:
self
.
quantization
=
quant_method
quant_method
=
"
gptq_
marlin"
if
self
.
quantization
==
"
marlin
"
:
self
.
quantization
=
quant_method
# Verify quantization configurations.
if
self
.
quantization
is
None
:
self
.
quantization
=
quant_method
elif
self
.
quantization
!=
quant_method
:
...
...
@@ -165,17 +201,17 @@ class ModelConfig:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
if
self
.
quantization
!=
"
marlin"
:
if
(
self
.
quantization
not
in
[
"marlin"
,
"gptq_
marlin"
])
:
logger
.
warning
(
f
"
{
self
.
quantization
}
quantization is not fully "
"%s
quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models."
)
"non-quantized models."
,
self
.
quantization
)
def
_verify_cuda_graph
(
self
)
->
None
:
if
self
.
max_
context
_len_to_capture
is
None
:
self
.
max_
context
_len_to_capture
=
self
.
max_model_len
self
.
max_
context
_len_to_capture
=
min
(
self
.
max_
context
_len_to_capture
,
self
.
max_model_len
)
if
self
.
max_
seq
_len_to_capture
is
None
:
self
.
max_
seq
_len_to_capture
=
self
.
max_model_len
self
.
max_
seq
_len_to_capture
=
min
(
self
.
max_
seq
_len_to_capture
,
self
.
max_model_len
)
def
verify_with_parallel_config
(
self
,
...
...
@@ -271,6 +307,11 @@ class ModelConfig:
return
max
(
1
,
total_num_kv_heads
//
parallel_config
.
tensor_parallel_size
)
def get_num_attention_heads(self,
                            parallel_config: "ParallelConfig") -> int:
    """Return the attention-head count divided by the tensor-parallel size.

    Args:
        parallel_config: Object exposing ``tensor_parallel_size``.

    Returns:
        ``hf_text_config.num_attention_heads`` floor-divided by
        ``parallel_config.tensor_parallel_size``.
    """
    total_heads = self.hf_text_config.num_attention_heads
    tp_size = parallel_config.tensor_parallel_size
    return total_heads // tp_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
    """Return the hidden-layer count divided by the pipeline-parallel size.

    Args:
        parallel_config: Object exposing ``pipeline_parallel_size``.

    Returns:
        ``hf_text_config.num_hidden_layers`` floor-divided by
        ``parallel_config.pipeline_parallel_size``.
    """
    return (self.hf_text_config.num_hidden_layers //
            parallel_config.pipeline_parallel_size)
...
...
@@ -330,7 +371,8 @@ class CacheConfig:
elif
self
.
cache_dtype
==
"fp8"
:
if
not
is_hip
():
nvcc_cuda_version
=
get_nvcc_cuda_version
()
if
nvcc_cuda_version
<
Version
(
"11.8"
):
if
nvcc_cuda_version
is
not
None
\
and
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
ValueError
(
"FP8 is not supported when cuda version is"
"lower than 11.8."
)
...
...
@@ -360,7 +402,7 @@ class CacheConfig:
if
cpu_memory_usage
>
0.7
*
total_cpu_memory
:
raise
ValueError
(
"Too large swap space. "
+
msg
)
elif
cpu_memory_usage
>
0.4
*
total_cpu_memory
:
logger
.
warning
(
"Possibly too large swap space.
"
+
msg
)
logger
.
warning
(
"Possibly too large swap space.
%s"
,
msg
)
@
dataclass
...
...
@@ -574,8 +616,9 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
if
enable_chunked_prefill
:
# For chunked prefill, choose the well-tuned batch size.
self
.
max_num_batched_tokens
=
768
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
self
.
max_num_batched_tokens
=
512
else
:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
...
...
@@ -658,6 +701,8 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
...
...
@@ -684,6 +729,10 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
...
@@ -718,39 +767,57 @@ class SpeculativeConfig:
draft_code_revision
=
None
draft_quantization
=
None
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
dtype
=
target_model_config
.
dtype
,
seed
=
target_model_config
.
seed
,
revision
=
draft_revision
,
code_revision
=
draft_code_revision
,
tokenizer_revision
=
target_model_config
.
tokenizer_revision
,
max_model_len
=
None
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_context_len_to_capture
=
target_model_config
.
max_context_len_to_capture
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
draft_model_config
.
max_model_len
,
target_model_config
.
max_model_len
,
))
if
speculative_model
==
"[ngram]"
:
assert
(
ngram_prompt_lookup_max
is
not
None
and
ngram_prompt_lookup_max
>
0
)
if
ngram_prompt_lookup_min
is
None
:
ngram_prompt_lookup_min
=
0
else
:
assert
ngram_prompt_lookup_max
>
ngram_prompt_lookup_min
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
))
# TODO: currently we still need to extract vocab_size from the target
# model config; in the future, we may try to refactor it out and set
# the draft-related config to None here.
draft_model_config
=
target_model_config
draft_parallel_config
=
target_parallel_config
else
:
ngram_prompt_lookup_max
=
0
ngram_prompt_lookup_min
=
0
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
dtype
=
target_model_config
.
dtype
,
seed
=
target_model_config
.
seed
,
revision
=
draft_revision
,
code_revision
=
draft_code_revision
,
tokenizer_revision
=
target_model_config
.
tokenizer_revision
,
max_model_len
=
None
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_seq_len_to_capture
=
target_model_config
.
max_seq_len_to_capture
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
draft_model_config
.
max_model_len
,
target_model_config
.
max_model_len
,
))
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
))
return
SpeculativeConfig
(
draft_model_config
,
draft_parallel_config
,
num_speculative_tokens
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
)
@
staticmethod
...
...
@@ -818,6 +885,8 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
ngram_prompt_lookup_max
:
int
,
ngram_prompt_lookup_min
:
int
,
):
"""Create a SpeculativeConfig object.
...
...
@@ -830,6 +899,8 @@ class SpeculativeConfig:
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
self
.
_verify_args
()
...
...
@@ -853,7 +924,10 @@ class SpeculativeConfig:
return
self
.
num_speculative_tokens
def
__repr__
(
self
)
->
str
:
draft_model
=
self
.
draft_model_config
.
model
if
self
.
ngram_prompt_lookup_max
>
0
:
draft_model
=
"[ngram]"
else
:
draft_model
=
self
.
draft_model_config
.
model
num_spec_tokens
=
self
.
num_speculative_tokens
return
f
"SpeculativeConfig(
{
draft_model
=
}
,
{
num_spec_tokens
=
}
)"
...
...
@@ -862,6 +936,7 @@ class SpeculativeConfig:
class
LoRAConfig
:
max_lora_rank
:
int
max_loras
:
int
fully_sharded_loras
:
bool
=
False
max_cpu_loras
:
Optional
[
int
]
=
None
lora_dtype
:
Optional
[
torch
.
dtype
]
=
None
lora_extra_vocab_size
:
int
=
256
...
...
@@ -898,8 +973,8 @@ class LoRAConfig:
"awq"
,
"gptq"
]:
# TODO support marlin and squeezellm
logger
.
warning
(
f
"
{
model_config
.
quantization
}
quantization is not "
"tested with LoRA yet."
)
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
model_config
.
quantization
)
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
if
scheduler_config
.
max_num_batched_tokens
>
65528
:
...
...
@@ -1008,7 +1083,7 @@ def _get_and_verify_dtype(
pass
else
:
# Casting between float16 and bfloat16 is allowed with a warning.
logger
.
warning
(
f
"Casting
{
config_dtype
}
to
{
torch_dtype
}
."
)
logger
.
warning
(
"Casting
%s to %s."
,
config_dtype
,
torch_dtype
)
return
torch_dtype
...
...
@@ -1051,12 +1126,12 @@ def _get_and_verify_max_len(
logger
.
warning
(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f
"
{
possible_keys
}
. Assuming the model's maximum length is
"
f
"
{
default_max_len
}
."
)
"%d
. Assuming the model's maximum length is
%d."
,
possible_keys
,
default_max_len
)
derived_max_model_len
=
default_max_len
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
:
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
assert
"factor"
in
rope_scaling
scaling_factor
=
rope_scaling
[
"factor"
]
if
rope_scaling
[
"type"
]
==
"yarn"
:
...
...
@@ -1084,6 +1159,22 @@ def _get_and_verify_max_len(
return
int
(
max_model_len
)
def
get_served_model_name
(
model
:
str
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an
empty list, the fallback is to use `model`.
"""
if
not
served_model_name
:
return
model
if
isinstance
(
served_model_name
,
list
):
return
served_model_name
[
0
]
return
served_model_name
@
dataclass
class
DecodingConfig
:
"""Dataclass which contains the decoding strategy of the engine"""
...
...
vllm/core/block/block_table.py
View file @
1591c68f
...
...
@@ -40,7 +40,9 @@ class BlockTable:
):
self
.
_block_size
=
block_size
self
.
_allocator
=
block_allocator
self
.
_blocks
:
Optional
[
List
[
Block
]]
=
_blocks
if
_blocks
is
None
:
_blocks
=
[]
self
.
_blocks
:
List
[
Block
]
=
_blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
...
...
@@ -104,7 +106,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert
self
.
_is_allocated
assert
self
.
_blocks
is
not
None
assert
len
(
self
.
_blocks
)
>
0
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
...
...
@@ -141,6 +143,7 @@ class BlockTable:
blocks_to_allocate
=
cdiv
(
slots_to_allocate
,
self
.
_block_size
)
for
_
in
range
(
blocks_to_allocate
):
assert
len
(
self
.
_blocks
)
>
0
self
.
_blocks
.
append
(
self
.
_allocator
.
allocate_mutable
(
prev_block
=
self
.
_blocks
[
-
1
],
device
=
device
))
...
...
@@ -159,6 +162,7 @@ class BlockTable:
the current instance.
"""
assert
self
.
_is_allocated
assert
len
(
self
.
_blocks
)
>
0
forked_blocks
=
self
.
_allocator
.
fork
(
self
.
_blocks
[
-
1
])
return
BlockTable
(
block_size
=
self
.
_block_size
,
...
...
@@ -177,10 +181,10 @@ class BlockTable:
assert
self
.
_is_allocated
for
block
in
self
.
_blocks
:
self
.
_allocator
.
free
(
block
)
self
.
_blocks
=
None
self
.
_blocks
=
[]
@
property
def
physical_block_ids
(
self
)
->
List
[
int
]:
def
physical_block_ids
(
self
)
->
List
[
Optional
[
int
]
]
:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
...
...
@@ -235,7 +239,7 @@ class BlockTable:
def
_get_all_token_ids
(
self
)
->
List
[
int
]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids
=
[]
token_ids
:
List
[
int
]
=
[]
if
not
self
.
_is_allocated
:
return
token_ids
...
...
@@ -247,7 +251,7 @@ class BlockTable:
@
property
def
_is_allocated
(
self
)
->
bool
:
return
self
.
_blocks
is
not
None
return
len
(
self
.
_blocks
)
>
0
@
property
def
_num_empty_slots
(
self
)
->
int
:
...
...
vllm/core/block/common.py
View file @
1591c68f
from
collections
import
defaultdict
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
...
...
@@ -7,7 +7,19 @@ BlockId = int
RefCount
=
int
class
RefCounter
:
class
RefCounterProtocol
(
Protocol
):
def
incr
(
self
,
block_id
:
BlockId
)
->
RefCount
:
raise
NotImplementedError
def
decr
(
self
,
block_id
:
BlockId
)
->
RefCount
:
raise
NotImplementedError
def
get
(
self
,
block_id
:
BlockId
)
->
RefCount
:
raise
NotImplementedError
class
RefCounter
(
RefCounterProtocol
):
"""A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their
...
...
@@ -54,7 +66,7 @@ class RefCounter:
return
ReadOnlyRefCounter
(
self
)
class
ReadOnlyRefCounter
:
class
ReadOnlyRefCounter
(
RefCounterProtocol
)
:
"""A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the
...
...
@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
def
__init__
(
self
,
refcounter
:
RefCounter
,
refcounter
:
RefCounter
Protocol
,
allocator
:
BlockAllocator
,
):
self
.
_copy_on_writes
:
Dict
[
BlockId
,
List
[
BlockId
]]
=
defaultdict
(
list
)
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
1591c68f
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
FrozenSet
,
List
,
Optional
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
BlockId
,
DeviceAwareBlockAllocator
)
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.block.prefix_caching_block
import
PrefixCachingBlockAllocator
...
...
@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids
=
block_ids
[
num_gpu_blocks
:]
if
allocator_type
==
"naive"
:
gpu_allocator
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
gpu_allocator
:
BlockAllocator
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
# type: ignore
num_blocks
=
num_gpu_blocks
,
block_size
=
block_size
,
block_ids
=
gpu_block_ids
,
)
cpu_allocator
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
cpu_allocator
:
BlockAllocator
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
# type: ignore
num_blocks
=
num_cpu_blocks
,
block_size
=
block_size
,
block_ids
=
cpu_block_ids
,
...
...
@@ -105,7 +105,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device
.
GPU
:
gpu_block_allocator
,
}
self
.
_block_ids_to_allocator
=
{}
self
.
_block_ids_to_allocator
:
Dict
[
int
,
BlockAllocator
]
=
{}
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
block_id
in
allocator
.
all_block_ids
:
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
...
...
@@ -149,7 +149,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
block (Block): The block to be freed.
"""
allocator
=
self
.
_block_ids_to_allocator
[
block
.
block_id
]
block_id
=
block
.
block_id
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
return
allocator
.
free
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
...
...
@@ -163,7 +165,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
"""
allocator
=
self
.
_block_ids_to_allocator
[
last_block
.
block_id
]
block_id
=
last_block
.
block_id
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
return
allocator
.
fork
(
last_block
)
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
...
...
@@ -171,13 +175,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
device (Device): The device for which to query the number of free
blocks.
blocks.
AssertionError is raised if None is passed.
Returns:
int: The number of free blocks available on the specified device.
"""
return
self
.
_allocators
[
device
].
get_num_free_blocks
()
def
get_num_total_blocks
(
self
,
device
:
Device
)
->
int
:
return
self
.
_allocators
[
device
].
get_num_total_blocks
()
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
...
...
@@ -190,10 +197,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
clear_copy_on_writes
()
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_accessed
(
block_ids
,
now
)
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as computed, only used for prefix caching."""
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
()
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
...
...
@@ -202,5 +217,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return
self
.
_allocators
[
device
].
get_common_computed_block_ids
(
seq_block_ids
)
def
all_block_ids
(
self
)
->
frozenset
[
int
]:
@
property
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
return
frozenset
(
self
.
_block_ids_to_allocator
.
keys
())
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
raise
NotImplementedError
vllm/core/block/interfaces.py
View file @
1591c68f
...
...
@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
from
vllm.utils
import
Device
BlockId
=
int
class
Block
(
ABC
):
...
...
@@ -15,6 +17,12 @@ class Block(ABC):
def
block_id
(
self
)
->
Optional
[
int
]:
pass
@
block_id
.
setter
@
abstractmethod
def
block_id
(
self
,
value
:
Optional
[
int
])
->
None
:
"""NOTE: Do not use this API outside Block."""
self
.
_block_id
=
value
@
property
@
abstractmethod
def
token_ids
(
self
)
->
List
[
int
]:
...
...
@@ -35,6 +43,27 @@ class Block(ABC):
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
pass
@
property
@
abstractmethod
def
computed
(
self
)
->
bool
:
raise
NotImplementedError
@
computed
.
setter
@
abstractmethod
def
computed
(
self
,
value
)
->
bool
:
"""Should be only used by PrefixCachingBlockAllocator"""
raise
NotImplementedError
@
property
@
abstractmethod
def
last_accessed
(
self
)
->
float
:
raise
NotImplementedError
@
last_accessed
.
setter
@
abstractmethod
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
raise
NotImplementedError
class
Factory
(
Protocol
):
@
abstractmethod
...
...
@@ -48,6 +77,17 @@ class Block(ABC):
)
->
"Block"
:
pass
@
property
@
abstractmethod
def
content_hash
(
self
)
->
Optional
[
int
]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return
None
class
BlockAllocator
(
ABC
):
...
...
@@ -57,7 +97,7 @@ class BlockAllocator(ABC):
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
]
,
device
:
Device
)
->
Block
:
token_ids
:
List
[
int
])
->
Block
:
pass
@
abstractmethod
...
...
@@ -69,7 +109,11 @@ class BlockAllocator(ABC):
pass
@
abstractmethod
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
def
get_num_total_blocks
(
self
)
->
int
:
pass
@
abstractmethod
def
get_num_free_blocks
(
self
)
->
int
:
pass
@
property
...
...
@@ -82,7 +126,12 @@ class BlockAllocator(ABC):
pass
@
abstractmethod
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
pass
@
abstractmethod
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
...
...
@@ -90,14 +139,25 @@ class BlockAllocator(ABC):
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
@
abstractmethod
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
"BlockId"
]:
"""NOTE: This should not be used besides Block"""
pass
@
abstractmethod
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
"""NOTE: This should not be used besides Block"""
pass
class
NoFreeBlocksError
(
ValueError
):
pass
class
DeviceAwareBlockAllocator
(
BlockAllocator
):
class
DeviceAwareBlockAllocator
(
ABC
):
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
pass
@
abstractmethod
...
...
@@ -108,3 +168,38 @@ class DeviceAwareBlockAllocator(BlockAllocator):
@
abstractmethod
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
pass
@
abstractmethod
def
get_num_total_blocks
(
self
,
device
:
Device
)
->
int
:
pass
@
abstractmethod
def
free
(
self
,
block
:
Block
)
->
None
:
pass
@
abstractmethod
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
pass
@
property
@
abstractmethod
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
pass
@
abstractmethod
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
pass
@
abstractmethod
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
pass
@
abstractmethod
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
vllm/core/block/naive_block.py
View file @
1591c68f
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Set
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
BlockId
=
int
Refcount
=
int
...
...
@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
allocator
=
self
,
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
])
->
Block
:
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
...
...
@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Block: The newly allocated immutable block.
"""
assert
device
is
None
block
=
self
.
allocate_mutable
(
prev_block
=
prev_block
)
block
.
append_token_ids
(
token_ids
)
return
block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new mutable block, linked to the previous block.
Args:
...
...
@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
assert
device
is
None
block_id
=
self
.
_allocate_new_block_id
()
return
self
.
_create_block
(
prev_block
=
prev_block
,
...
...
@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
)
def
free
(
self
,
block
:
Block
)
->
None
:
assert
block
.
block_id
is
not
None
self
.
_free_block_id
(
block
.
block_id
)
# Mark the block as having no allocation.
...
...
@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
for
block
in
source_blocks
:
# Increment refcount for each block.
assert
block
.
block_id
is
not
None
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
assert
refcount
!=
1
,
"can't fork free'd block"
...
...
@@ -129,6 +136,9 @@ class NaiveBlockAllocator(BlockAllocator):
def
get_num_free_blocks
(
self
)
->
int
:
return
len
(
self
.
_free_block_indices
)
def
get_num_total_blocks
(
self
)
->
int
:
return
len
(
self
.
_all_block_indices
)
def
_allocate_new_block_id
(
self
)
->
BlockId
:
if
not
self
.
_free_block_indices
:
raise
BlockAllocator
.
NoFreeBlocksError
()
...
...
@@ -148,7 +158,7 @@ class NaiveBlockAllocator(BlockAllocator):
return
self
.
_refcounter
@
property
def
all_block_ids
(
self
):
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]
:
return
self
.
_all_block_indices
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
...
...
@@ -174,7 +184,16 @@ class NaiveBlockAllocator(BlockAllocator):
"""
return
self
.
_cow_tracker
.
clear_cows
()
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
"""Mark blocks as accessed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as computed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
...
...
@@ -191,6 +210,9 @@ class NaiveBlockAllocator(BlockAllocator):
"""
return
[]
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
class
NaiveBlock
(
Block
):
"""An implementation of the Block class that does not support prefix
...
...
@@ -215,13 +237,13 @@ class NaiveBlock(Block):
"""
def
__init__
(
self
,
prev_block
:
Block
,
prev_block
:
Optional
[
Block
]
,
token_ids
:
List
[
int
],
block_size
:
int
,
allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
_cow_target
:
Optional
[
Block
]
=
None
):
self
.
_token_ids
=
[]
self
.
_token_ids
:
List
[
int
]
=
[]
self
.
_block_size
=
block_size
self
.
_prev_block
=
prev_block
self
.
_block_id
=
block_id
...
...
@@ -247,6 +269,22 @@ class NaiveBlock(Block):
assert
self
.
num_empty_slots
>=
len
(
token_ids
)
self
.
_token_ids
.
extend
(
token_ids
)
@
property
def
computed
(
self
)
->
bool
:
raise
NotImplementedError
@
computed
.
setter
def
computed
(
self
,
value
)
->
None
:
raise
NotImplementedError
@
property
def
last_accessed
(
self
)
->
float
:
raise
NotImplementedError
@
last_accessed
.
setter
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
raise
NotImplementedError
@
property
def
block_id
(
self
)
->
Optional
[
int
]:
return
self
.
_block_id
...
...
@@ -267,9 +305,14 @@ class NaiveBlock(Block):
def
token_ids
(
self
)
->
List
[
int
]:
return
self
.
_token_ids
@
property
def
block_size
(
self
)
->
int
:
return
self
.
_block_size
@
property
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
return
self
.
_prev_block
@
property
def
content_hash
(
self
)
->
Optional
[
int
]:
return
None
vllm/core/block/prefix_caching_block.py
View file @
1591c68f
"""Token blocks."""
from
itertools
import
takewhile
from
os.path
import
commonprefix
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.evictor_v2
import
EvictionPolicy
,
Evictor
,
make_evictor
PrefixHash
=
int
BlockId
=
int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find a block still holding _DEFAULT_LAST_ACCESSED_TIME,
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME
=
-
1
class
PrefixCachingBlockAllocator
(
BlockAllocator
):
...
...
@@ -27,26 +32,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
from 0 to num_blocks - 1.
"""
# TODO last access time / evictor integration
def
__init__
(
self
,
num_blocks
:
int
,
block_size
:
int
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
eviction_policy
:
EvictionPolicy
=
EvictionPolicy
.
LRU
,
):
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self
.
_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
# of self._cached_blocks.
self
.
_unused_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
# A mapping of blockId to Block to track those cached blocks
self
.
_blocks
:
Dict
[
BlockId
,
Block
]
=
{}
# An allocator for blocks that do not have prefix hashes.
self
.
_hashless_allocator
=
NaiveBlockAllocator
(
create_block
=
self
.
_create_block
,
create_block
=
self
.
_create_block
,
# type: ignore
num_blocks
=
num_blocks
,
block_size
=
block_size
,
block_ids
=
block_ids
,
...
...
@@ -54,6 +56,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self
.
_block_size
=
block_size
# Evictor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self
.
evictor
:
Evictor
=
make_evictor
(
eviction_policy
)
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
# blocks.
...
...
@@ -72,6 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size
:
int
,
allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
computed
:
bool
=
False
,
)
->
Block
:
# Bind block to self.
allocator
=
self
...
...
@@ -82,10 +89,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size
=
block_size
,
block_id
=
block_id
,
prefix_caching_allocator
=
allocator
,
computed
=
computed
,
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
])
->
Block
:
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
...
...
@@ -96,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Returns:
Block: The allocated immutable block.
"""
assert
device
is
None
assert_prefix_caching_block_or_none
(
prev_block
)
block
=
self
.
_create_block
(
...
...
@@ -109,65 +120,95 @@ class PrefixCachingBlockAllocator(BlockAllocator):
cached_block_id
=
self
.
_cached_blocks
.
get
(
block
.
content_hash
,
None
)
if
cached_block_id
is
not
None
:
block
.
block_id
=
cached_block_id
self
.
_incr_refcount_cached_block
(
block
.
content_hash
,
block
.
block_id
)
self
.
_incr_refcount_cached_block
(
block
,
block
.
block_id
)
return
block
block
=
self
.
allocate_mutable
(
prev_block
)
block
.
append_token_ids
(
token_ids
)
assert
block
.
content_hash
is
not
None
# TODO computed bit
return
block
def
allocate_mutable
(
self
,
prev_block
:
Block
)
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
Args:
prev_block (Block): The previous block in the sequence.
None is not allowed unlike it is super class.
Returns:
Block: The allocated mutable block.
"""
assert
device
is
None
assert_prefix_caching_block_or_none
(
prev_block
)
try
:
return
self
.
_hashless_allocator
.
allocate_mutable
(
block
=
self
.
_hashless_allocator
.
allocate_mutable
(
prev_block
=
prev_block
)
assert
block
.
block_id
not
in
self
.
_blocks
assert
block
.
block_id
is
not
None
self
.
_blocks
[
block
.
block_id
]
=
block
return
block
except
BlockAllocator
.
NoFreeBlocksError
:
# We must check the unused cached blocks before raising OOM.
pass
if
self
.
_unused_cached_blocks
:
# TODO policy for selecting block to remove
content_hash_to_evict
=
next
(
iter
(
self
.
_unused_cached_blocks
))
# If the evictor has blocks available for eviction, evict a block
# and return it.
if
self
.
evictor
.
num_blocks
>
0
:
block_id
,
content_hash_to_evict
=
self
.
evictor
.
evict
()
# Here we may have scenario that several blocks have
# the same content hash, but due to the latter coming block
# is coming from mutable to immutable path, their physical
# block is added into evictor.
# However in this case, we shall not pop the _cached_blocks,
# as the same content is still used by others, which means
# we need to check the refcount before deciding to pop the entry.
# Clear content hash mapping; the block will be overwritten.
del
self
.
_cached_blocks
[
content_hash_to_evict
]
_block_id
=
self
.
_cached_blocks
[
content_hash_to_evict
]
refcount
=
self
.
_refcounter
.
get
(
_block_id
)
if
refcount
==
1
:
self
.
_cached_blocks
.
pop
(
content_hash_to_evict
)
assert
_block_id
==
block_id
block_id
=
self
.
_unused_cached_blocks
.
pop
(
content_hash_to_evict
)
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
assert
ref
co
u
nt
==
1
self
.
_refcounter
.
incr
(
block_id
)
# the block comes from the evictor, so it already contains a computed result
block
=
self
.
_create_block
(
prev_block
=
prev_block
,
token_ids
=
[],
block_size
=
self
.
_block_size
,
allocator
=
self
,
block_id
=
block_id
,
computed
=
True
,
)
assert
block
.
content_hash
is
None
assert
block
.
block_id
not
in
self
.
_blocks
assert
block
.
block_id
is
not
None
self
.
_blocks
[
block
.
block_id
]
=
block
return
block
# No block available in hashless allocator, nor in unused cache blocks.
raise
BlockAllocator
.
NoFreeBlocksError
()
def
_incr_refcount_cached_block
(
self
,
content_hash
:
int
,
def
_incr_refcount_cached_block
(
self
,
block
:
Block
,
block_id
:
BlockId
)
->
None
:
# since block is already computed, mark it
block
.
computed
=
True
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
if
refcount
==
1
:
assert
content_hash
in
self
.
_unused_cached_blocks
del
self
.
_unused_cached_blocks
[
content_hash
]
# if block get referred, then it shall not be in evictor
# and put it into _blocks for tracking
if
block_id
in
self
.
evictor
:
self
.
evictor
.
remove
(
block_id
)
self
.
_blocks
[
block_id
]
=
block
def
free
(
self
,
block
:
Block
)
->
None
:
"""Decrement the refcount of the block. If the decremented refcount is
...
...
@@ -180,6 +221,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
is
not
None
),
"freeing unallocated block is undefined"
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
block
.
block_id
=
None
def
_free_block_id_for_block
(
self
,
block_id
:
BlockId
,
...
...
@@ -187,15 +229,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert
isinstance
(
block
,
PrefixCachingBlock
)
if
block
.
content_hash
is
None
:
refcount
=
self
.
_refcounter
.
get
(
block_id
)
# In the fork case a block can hold more than one reference,
# so we cannot stop tracking it while its ref count is larger than 1
if
refcount
<=
1
:
assert
block
.
block_id
is
not
None
del
self
.
_blocks
[
block
.
block_id
]
return
self
.
_hashless_allocator
.
free
(
block
)
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
# If no longer used, add the block to the
unused cached blocks
.
# If no longer used, add the block to the
evictor
.
if
refcount
==
0
:
assert
block
.
content_hash
not
in
self
.
_unused_cached_blocks
assert
block
.
content_hash
in
self
.
_cached_blocks
self
.
_unused_cached_blocks
[
block
.
content_hash
]
=
block_id
assert
block
.
block_id
is
not
None
del
self
.
_blocks
[
block
.
block_id
]
self
.
evictor
.
add
(
block
.
block_id
,
block
.
content_hash
,
block
.
num_tokens_total
,
block
.
last_accessed
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
...
...
@@ -228,18 +278,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return
forked_blocks
def
get_num_free_blocks
(
self
)
->
int
:
def
get_num_free_blocks
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
int
:
assert
device
is
None
# The number of free blocks is the number of hashless free blocks
# plus the number of hashful blocks that are unused.
return
self
.
_hashless_allocator
.
get_num_free_blocks
()
+
len
(
self
.
_unused_cached_blocks
)
# plus the number of blocks evictor could free from its list.
return
self
.
_hashless_allocator
.
get_num_free_blocks
(
)
+
self
.
evictor
.
num_blocks
def
get_num_total_blocks
(
self
)
->
int
:
return
self
.
_hashless_allocator
.
get_num_total_blocks
()
@
property
def
all_block_ids
(
self
)
->
f
rozen
s
et
[
int
]:
def
all_block_ids
(
self
)
->
F
rozen
S
et
[
int
]:
return
self
.
_hashless_allocator
.
all_block_ids
def
promote_to_immutable_block
(
self
,
block
:
"PrefixCachingBlock"
)
->
BlockId
:
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
having the same prefix.
...
...
@@ -249,7 +302,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block.
Args:
block
(PrefixCachingBlock)
: The mutable block to be promoted.
block: The mutable block to be promoted.
Returns:
BlockId: Either the original block index, or the block index of
...
...
@@ -266,7 +319,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else
:
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
self
.
_incr_refcount_cached_block
(
block
.
content_hash
,
self
.
_cached_blocks
[
block
.
content_hash
])
block
,
self
.
_cached_blocks
[
block
.
content_hash
])
return
self
.
_cached_blocks
[
block
.
content_hash
]
...
...
@@ -293,29 +346,63 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
return
self
.
_cow_tracker
.
clear_cows
()
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
"""Mark blocks as accessed, used in prefix caching.
If the block is added into evictor, we need to update corresponding
info in evictor's metadata.
"""
for
block_id
in
block_ids
:
if
block_id
in
self
.
_blocks
:
self
.
_blocks
[
block_id
].
last_accessed
=
now
elif
block_id
in
self
.
evictor
:
self
.
evictor
.
update
(
block_id
,
now
)
else
:
raise
ValueError
(
"Mark block as accessed which is not belonged to GPU"
)
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as computed, used in prefix caching."""
# TODO Track computed blocks.
pass
for
block_id
in
block_ids
:
if
block_id
in
self
.
_blocks
:
# only full blocks are valid for prefix caching
if
self
.
_blocks
[
block_id
].
is_full
:
self
.
_blocks
[
block_id
].
computed
=
True
elif
block_id
not
in
self
.
evictor
:
raise
ValueError
(
f
"Mark
{
block_id
=
}
as computed which "
"is not belonged to GPU"
)
def
block_is_computed
(
self
,
block_id
:
int
)
->
bool
:
if
block_id
in
self
.
_blocks
:
return
self
.
_blocks
[
block_id
].
computed
else
:
return
block_id
in
self
.
evictor
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
Only those blocks that are immutable and already marked
computed are taken into consideration.
"""
# TODO: Track computed blocks.
computed
=
lambda
block_id
:
False
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in model
# runner.
ids_list
=
[
takewhile
(
lambda
block_id
:
computed
(
block_id
),
seq
[:
-
1
])
for
seq
in
seq_block_ids
list
(
takewhile
(
lambda
block_id
:
self
.
block_is_computed
(
block_id
),
seq
[:
-
1
]))
for
seq
in
seq_block_ids
]
return
commonprefix
([
ids
for
ids
in
ids_list
if
ids
!=
[]])
# It returns a list of int although type annotation says list of string.
return
commonprefix
([
ids
for
ids
in
ids_list
# type: ignore
if
ids
!=
[]
])
class
PrefixCachingBlock
(
Block
):
...
...
@@ -332,7 +419,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (
PrefixCaching
BlockAllocator): The prefix
prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
...
...
@@ -340,17 +427,25 @@ class PrefixCachingBlock(Block):
def
__init__
(
self
,
prev_block
:
Optional
[
"PrefixCaching
Block
"
],
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
block_size
:
int
,
prefix_caching_allocator
:
PrefixCaching
BlockAllocator
,
prefix_caching_allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
computed
:
bool
=
False
,
):
assert
isinstance
(
prefix_caching_allocator
,
PrefixCachingBlockAllocator
),
(
"Currently this class is only tested with "
"PrefixCachingBlockAllocator."
)
assert_prefix_caching_block_or_none
(
prev_block
)
self
.
_prev_block
=
prev_block
self
.
_cached_content_hash
:
Optional
[
int
]
=
None
self
.
_cached_num_tokens_total
:
Optional
[
int
]
=
None
self
.
_prefix_caching_allocator
=
prefix_caching_allocator
self
.
_last_accessed
:
float
=
_DEFAULT_LAST_ACCESSED_TIME
self
.
_computed
=
computed
self
.
_block
=
NaiveBlock
(
prev_block
=
prev_block
,
...
...
@@ -361,6 +456,22 @@ class PrefixCachingBlock(Block):
_cow_target
=
self
,
)
@
property
def
computed
(
self
)
->
bool
:
return
self
.
_computed
@
computed
.
setter
def
computed
(
self
,
value
)
->
None
:
self
.
_computed
=
value
@
property
def
last_accessed
(
self
)
->
float
:
return
self
.
_last_accessed
@
last_accessed
.
setter
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
self
.
_last_accessed
=
last_accessed_ts
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
...
...
@@ -398,6 +509,27 @@ class PrefixCachingBlock(Block):
def
num_empty_slots
(
self
)
->
int
:
return
self
.
_block
.
num_empty_slots
@
property
def
num_tokens_total
(
self
)
->
int
:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if
self
.
_cached_num_tokens_total
is
not
None
:
return
self
.
_cached_num_tokens_total
_block
:
Optional
[
Block
]
=
self
self
.
_cached_num_tokens_total
=
0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while
_block
is
not
None
:
self
.
_cached_num_tokens_total
+=
len
(
_block
.
token_ids
)
_block
=
_block
.
prev_block
return
self
.
_cached_num_tokens_total
@
property
def
block_size
(
self
)
->
int
:
return
self
.
_block
.
block_size
...
...
@@ -428,8 +560,10 @@ class PrefixCachingBlock(Block):
return
None
is_first_block
=
self
.
_prev_block
is
None
prev_block_hash
=
(
None
if
is_first_block
else
self
.
_prev_block
.
content_hash
)
prev_block_hash
=
(
None
if
is_first_block
else
self
.
_prev_block
.
content_hash
# type: ignore
)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment