Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1080 additions
and
359 deletions
+1080
-359
vllm/__init__.py
vllm/__init__.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+62
-12
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+9
-4
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+23
-22
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+220
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+55
-58
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+18
-18
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+33
-33
vllm/attention/layer.py
vllm/attention/layer.py
+7
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+19
-18
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+56
-19
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+11
-13
vllm/attention/selector.py
vllm/attention/selector.py
+13
-9
vllm/config.py
vllm/config.py
+149
-58
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+10
-6
vllm/core/block/common.py
vllm/core/block/common.py
+16
-4
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+35
-13
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+100
-5
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+53
-10
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+189
-55
No files found.
vllm/__init__.py
View file @
1591c68f
...
@@ -3,14 +3,14 @@
...
@@ -3,14 +3,14 @@
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.ray_utils
import
initialize_ray_cluster
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__dcu_version__
from
vllm.version
import
__dcu_version__
__version__
=
"0.4.
1
"
__version__
=
"0.4.
2
"
__all__
=
[
__all__
=
[
"LLM"
,
"LLM"
,
...
...
vllm/_custom_ops.py
View file @
1591c68f
...
@@ -39,17 +39,17 @@ def paged_attention_v1(
...
@@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads
:
int
,
num_kv_heads
:
int
,
scale
:
float
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
block_size
:
int
,
block_size
:
int
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
None
:
)
->
None
:
vllm_ops
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
vllm_ops
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
context_lens
,
block_size
,
max_
context_len
,
block_size
,
max_
seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
kv_cache_dtype
,
kv_scale
)
def
paged_attention_v2
(
def
paged_attention_v2
(
...
@@ -63,17 +63,17 @@ def paged_attention_v2(
...
@@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads
:
int
,
num_kv_heads
:
int
,
scale
:
float
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
block_size
:
int
,
block_size
:
int
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
None
:
)
->
None
:
vllm_ops
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
vllm_ops
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context
_lens
,
block_size
,
block_tables
,
seq
_lens
,
block_size
,
max_
context
_len
,
alibi_slopes
,
kv_cache_dtype
,
max_
seq
_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
)
kv_scale
)
...
@@ -153,11 +153,49 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -153,11 +153,49 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
size_n
,
size_k
)
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
# gptq_marlin
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
num_bits
)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
)
# fp8
# fp8
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
scaled_fp8_quant
(
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
float8_e4m3fn
)
vllm_ops
.
scaled_fp8_quant
(
output
,
input
,
scale
)
if
scale
is
None
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
vllm_ops
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
vllm_ops
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
return
output
,
scale
...
@@ -184,6 +222,18 @@ def reshape_and_cache(
...
@@ -184,6 +222,18 @@ def reshape_and_cache(
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
def
reshape_and_cache_flash
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
)
->
None
:
vllm_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
def
copy_blocks
(
key_caches
:
torch
.
Tensor
,
value_caches
:
torch
.
Tensor
,
def
copy_blocks
(
key_caches
:
torch
.
Tensor
,
value_caches
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
vllm_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
vllm_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
...
...
vllm/attention/backends/abstract.py
View file @
1591c68f
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
import
torch
import
torch
...
@@ -15,7 +16,7 @@ class AttentionBackend(ABC):
...
@@ -15,7 +16,7 @@ class AttentionBackend(ABC):
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata
PerStage
"
:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
staticmethod
...
@@ -50,13 +51,17 @@ class AttentionBackend(ABC):
...
@@ -50,13 +51,17 @@ class AttentionBackend(ABC):
class
AttentionMetadataPerStage
:
class
AttentionMetadataPerStage
:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if
skip_fields
is
None
:
skip_fields
=
set
()
# Note that if we add dataclasses as fields, they will need
# Note that if we add dataclasses as fields, they will need
# similar handling.
# similar handling.
return
{
return
{
field
.
name
:
getattr
(
self
,
field
.
name
)
field
.
name
:
getattr
(
self
,
field
.
name
)
for
field
in
fields
(
self
)
for
field
in
fields
(
self
)
if
field
.
name
not
in
skip_fields
}
}
...
...
vllm/attention/backends/flash_attn.py
View file @
1591c68f
...
@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
# (batch_size,). The sequence length per sequence. Sequence length means
prompt_lens
:
Optional
[
List
[
int
]]
# the computed tokens + new tokens None if it is a decoding.
# prompt_lens stored as a tensor.
seq_lens
:
Optional
[
List
[
int
]]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len,
sub
query_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq
_
len.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-------------------- seq
_
len ----------------------|
# |-
sub
query_len -|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# Maximum query length in the batch.
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
max_query_len
:
Optional
[
int
]
# When it is for decoding, it includes a new token.
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum prompt length in the batch.
max_prompt_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
...
@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
...
@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v
=
value
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_q
=
prefill_meta
.
max_
seq
_len
,
max_seqlen_k
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_k
=
prefill_meta
.
max_
seq
_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
...
@@ -245,10 +245,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -245,10 +245,11 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
...
@@ -257,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -257,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
context
_len
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
...
vllm/attention/backends/flashinfer.py
0 → 100644
View file @
1591c68f
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
import
flashinfer
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
except
ImportError
:
flashinfer
=
None
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadataPerStage
)
class
FlashInferBackend
(
AttentionBackend
):
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashInferImpl"
]:
return
FlashInferImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"FlashInferMetadata"
:
return
FlashInferMetadata
(
*
args
,
**
kwargs
)
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
)
->
None
:
raise
NotImplementedError
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
raise
NotImplementedError
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
128
,
256
]
@
dataclass
class
FlashInferMetadata
(
AttentionMetadataPerStage
):
is_prompt
:
bool
use_cuda_graph
:
bool
=
False
decode_wrapper
:
Optional
[
BatchDecodeWithPagedKVCacheWrapper
]
=
None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
max_seq_len
:
Optional
[
int
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer
:
Optional
[
torch
.
Tensor
]
=
None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
# The page indices of the paged kv cache
paged_kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len
:
Optional
[
torch
.
Tensor
]
=
None
# The number of query/output heads
num_qo_heads
:
Optional
[
int
]
=
None
# The number of key/value heads
num_kv_heads
:
Optional
[
int
]
=
None
# The dimension of the attention heads
head_dim
:
Optional
[
int
]
=
None
# Block size of vllm
page_size
:
Optional
[
int
]
=
None
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
def
__post_init__
(
self
):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes
=
FlashInferBackend
.
get_supported_head_sizes
()
if
self
.
head_dim
is
not
None
and
self
.
head_dim
\
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
"received
{
self
.
head_dim
}
."
)
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if
not
self
.
is_prompt
:
self
.
decode_wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
self
.
workspace_buffer
,
"NHD"
)
self
.
decode_wrapper
.
begin_forward
(
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
data_type
=
self
.
data_type
)
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
)
->
Dict
[
str
,
Any
]:
if
skip_fields
is
None
:
skip_fields
=
set
()
# We need to skip the decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields
.
add
(
'decode_wrapper'
)
return
super
().
asdict_zerocopy
(
skip_fields
)
class
FlashInferImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
if
sliding_window
is
not
None
:
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
alibi_slopes
=
alibi_slopes
self
.
scale
=
scale
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
FlashInferMetadata
],
kv_scale
:
float
):
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
attn_metadata
.
num_prefill_tokens
>
0
:
assert
attn_metadata
.
num_decode_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
attn_metadata
.
num_decode_tokens
>
0
:
assert
attn_metadata
.
num_prefill_tokens
==
0
,
(
"Chunked prefill is not supported with flashinfer yet."
)
if
kv_cache
is
not
None
:
# Use the same reshape and cache kernel as flash attention.
ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[:,
0
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
kv_cache_dtype
,
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
assert
prefill_meta
.
block_tables
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
else
:
raise
NotImplementedError
(
"Prefix caching is not supported with flashinfer yet."
)
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
query
=
query
.
contiguous
(
)
# Flashinfer requires query to be contiguous
output
=
attn_metadata
.
decode_metadata
.
decode_wrapper
.
forward
(
query
,
kv_cache
,
sm_scale
=
self
.
scale
,
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/rocm_flash_attn.py
View file @
1591c68f
"""Attention layer ROCm GPUs."""
"""Attention layer ROCm GPUs."""
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataPerStage
)
AttentionMetadataPerStage
)
...
@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
# (batch_size,). The sequence length per sequence. Sequence length means
prompt_lens
:
Optional
[
List
[
int
]]
# the computed tokens + new tokens None if it is a decoding.
# prompt_lens stored as a tensor.
seq_lens
:
Optional
[
List
[
int
]]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len,
sub
query_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq
_
len.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-------------------- seq
_
len ----------------------|
# |-
sub
query_len -|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# Maximum query length in the batch.
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
max_query_len
:
Optional
[
int
]
# When it is for decoding, it includes a new token.
# Maximum sequence length in the batch.
max_seq_len
:
Optional
[
int
]
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# Maximum prompt length in the batch.
max_prompt_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
...
@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
...
@@ -156,8 +156,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -156,8 +156,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
(
os
.
environ
.
get
(
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
))
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
triton_attention
)
...
@@ -248,41 +247,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -248,41 +247,36 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
assert
prefill_meta
.
prompt
_lens
is
not
None
assert
prefill_meta
.
seq
_lens
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# triton attention
# triton attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
or
self
.
use_naive_attn
:
if
self
.
use_triton_flash_attn
:
out
,
_
=
self
.
attn_func
(
query
,
key
,
value
,
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_seq_len
,
prefill_meta
.
max_seq_len
,
True
,
self
.
scale
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
use_naive_attn
:
out
=
self
.
attn_func
(
out
=
self
.
attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
prefill_meta
.
seq_lens
,
prefill_meta
.
prompt_lens
,
self
.
scale
,
self
.
scale
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
out
,
_
=
self
.
attn_func
(
query
,
key
,
value
,
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prompt_len
,
prefill_meta
.
max_prompt_len
,
True
,
self
.
scale
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
out
=
self
.
attn_func
(
out
=
self
.
attn_func
(
q
=
query
,
q
=
query
,
...
@@ -290,13 +284,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -290,13 +284,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v
=
value
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_q
=
prefill_meta
.
max_
seq
_len
,
max_seqlen_k
=
prefill_meta
.
max_
prompt
_len
,
max_seqlen_k
=
prefill_meta
.
max_
seq
_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
...
@@ -307,10 +303,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -307,10 +303,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
@@ -320,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -320,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
context
_len
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -337,13 +334,13 @@ def _naive_attention(
...
@@ -337,13 +334,13 @@ def _naive_attention(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
scale
:
float
,
scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
start
=
0
start
=
0
for
_
,
prompt
_len
in
enumerate
(
prompt
_lens
):
for
_
,
seq
_len
in
enumerate
(
seq
_lens
):
end
=
start
+
prompt
_len
end
=
start
+
seq
_len
out
=
_naive_masked_attention
(
out
=
_naive_masked_attention
(
query
[
start
:
end
],
query
[
start
:
end
],
key
[
start
:
end
],
key
[
start
:
end
],
...
@@ -352,7 +349,7 @@ def _naive_attention(
...
@@ -352,7 +349,7 @@ def _naive_attention(
)
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt
_len
start
+=
seq
_len
return
output
return
output
...
...
vllm/attention/backends/torch_sdpa.py
View file @
1591c68f
...
@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
...
@@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts.
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
prompt
_lens
:
Optional
[
List
[
int
]]
seq
_lens
:
Optional
[
List
[
int
]]
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
...
@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale
)
kv_scale
)
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
prompt
_lens
is
not
None
assert
attn_metadata
.
seq
_lens
is
not
None
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
(
kv_cache
is
None
or
attn_metadata
.
block_tables
.
numel
()
==
0
):
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
...
@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if
self
.
alibi_slopes
is
not
None
:
if
self
.
alibi_slopes
is
not
None
:
att_masks
=
_make_alibi_bias
(
att_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
prompt
_lens
)
# type: ignore
attn_metadata
.
seq
_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
elif
self
.
sliding_window
is
not
None
:
att_masks
=
_make_sliding_window_bias
(
att_masks
=
_make_sliding_window_bias
(
attn_metadata
.
prompt
_lens
,
self
.
sliding_window
,
attn_metadata
.
seq
_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
query
.
dtype
)
# type: ignore
else
:
else
:
att_masks
=
[
None
]
*
len
(
attn_metadata
.
prompt
_lens
)
att_masks
=
[
None
]
*
len
(
attn_metadata
.
seq
_lens
)
attn_metadata
.
attn_bias
=
att_masks
attn_metadata
.
attn_bias
=
att_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
...
@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output
=
torch
.
empty
(
output
=
torch
.
empty
(
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
),
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
for
prompt
_len
,
mask
in
zip
(
attn_metadata
.
prompt
_lens
,
for
seq
_len
,
mask
in
zip
(
attn_metadata
.
seq
_lens
,
attn_metadata
.
attn_bias
):
attn_metadata
.
attn_bias
):
end
=
start
+
prompt
_len
end
=
start
+
seq
_len
sub_out
=
scaled_dot_product_attention
(
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
...
@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
block_tables
,
attn_metadata
.
block_tables
,
attn_metadata
.
context_lens
,
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
max_
context
_len
,
attn_metadata
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
bias
=
torch
.
arange
(
prompt
_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq
_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(
prompt
_len, 1)`
# `bias = bias[None, :].repeat(
seq
_len, 1)`
# here. We find that both biases give the same results, but
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# the bias below more accurately follows the original ALiBi
# paper.
# paper.
...
@@ -221,7 +221,7 @@ def _make_alibi_bias(
...
@@ -221,7 +221,7 @@ def _make_alibi_bias(
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
inf_mask
=
torch
.
empty
(
(
1
,
prompt_len
,
prompt
_len
),
(
1
,
seq_len
,
seq
_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
...
@@ -229,14 +229,14 @@ def _make_alibi_bias(
...
@@ -229,14 +229,14 @@ def _make_alibi_bias(
def
_make_sliding_window_bias
(
def
_make_sliding_window_bias
(
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
tensor
=
torch
.
full
(
tensor
=
torch
.
full
(
(
1
,
prompt_len
,
prompt
_len
),
(
1
,
seq_len
,
seq
_len
),
dtype
=
dtype
,
dtype
=
dtype
,
fill_value
=
1
,
fill_value
=
1
,
)
)
...
...
vllm/attention/backends/xformers.py
View file @
1591c68f
...
@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
...
@@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
is_prompt
:
bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
# (batch_size,). The sequence length per sequence. Sequence length means
prompt_lens
:
Optional
[
List
[
int
]]
# the computed tokens + new tokens None if it is a decoding.
# prompt_lens stored as a tensor.
seq_lens
:
Optional
[
List
[
int
]]
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |-------------------- seq
_
len ----------------------|
# |-
sub
query_len -|
# |-
-
query_len
--
-|
# WARNING(sang): context_len has different definition depending on if it is
# Maximum query length in the batch.
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
max_query_len
:
Optional
[
int
]
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len
:
Optional
[
int
]
# FIXME: It is for flash attn.
# FIXME: It is for flash attn.
# Maximum
prompt
length in the batch.
# Maximum
sequence
length in the batch.
max_
prompt
_len
:
Optional
[
int
]
max_
seq
_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
...
@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
...
@@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
...
@@ -242,10 +241,11 @@ class XFormersImpl(AttentionImpl):
...
@@ -242,10 +241,11 @@ class XFormersImpl(AttentionImpl):
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
subquery_start_loc
,
prefill_meta
.
prompt
_lens_tensor
,
prefill_meta
.
seq
_lens_tensor
,
prefill_meta
.
context_lens
,
prefill_meta
.
context_lens
_tensor
,
prefill_meta
.
max_
sub
query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
...
@@ -256,8 +256,8 @@ class XFormersImpl(AttentionImpl):
...
@@ -256,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
context_lens
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_
context
_len
,
decode_meta
.
max_
seq
_len
,
attn_metadata
.
kv_cache_dtype
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
...
@@ -288,7 +288,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -288,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_metadata: Metadata for attention.
"""
"""
assert
attn_metadata
.
prompt
_lens
is
not
None
assert
attn_metadata
.
seq
_lens
is
not
None
original_query
=
query
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
# GQA/MQA requires the shape [B, M, G, H, K].
...
@@ -309,7 +309,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -309,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if
attn_metadata
.
attn_bias
is
None
:
if
attn_metadata
.
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
if
self
.
alibi_slopes
is
None
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
prompt
_lens
)
attn_metadata
.
seq
_lens
)
if
self
.
sliding_window
is
not
None
:
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
self
.
sliding_window
)
...
@@ -317,7 +317,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -317,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else
:
else
:
attn_metadata
.
attn_bias
=
_make_alibi_bias
(
attn_metadata
.
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
prompt
_lens
)
attn_metadata
.
seq
_lens
)
# No alibi slopes.
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# TODO(woosuk): Too many view operations. Let's try to reduce
...
@@ -342,8 +342,8 @@ class XFormersImpl(AttentionImpl):
...
@@ -342,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts.
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
original_query
)
output
=
torch
.
empty_like
(
original_query
)
start
=
0
start
=
0
for
i
,
prompt
_len
in
enumerate
(
attn_metadata
.
prompt
_lens
):
for
i
,
seq
_len
in
enumerate
(
attn_metadata
.
seq
_lens
):
end
=
start
+
prompt
_len
end
=
start
+
seq
_len
out
=
xops
.
memory_efficient_attention_forward
(
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
...
@@ -353,7 +353,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -353,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale
=
self
.
scale
)
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]))
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]))
start
+=
prompt
_len
start
+=
seq
_len
return
output
return
output
...
@@ -361,13 +361,13 @@ def _make_alibi_bias(
...
@@ -361,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
prompt
_lens
:
List
[
int
],
seq
_lens
:
List
[
int
],
)
->
LowerTriangularMaskWithTensorBias
:
)
->
LowerTriangularMaskWithTensorBias
:
attn_biases
=
[]
attn_biases
=
[]
for
prompt
_len
in
prompt
_lens
:
for
seq
_len
in
seq
_lens
:
bias
=
torch
.
arange
(
prompt
_len
,
dtype
=
dtype
)
bias
=
torch
.
arange
(
seq
_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(
prompt
_len, 1)`
# `bias = bias[None, :].repeat(
seq
_len, 1)`
# here. We find that both biases give the same results, but
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# the bias below more accurately follows the original ALiBi
# paper.
# paper.
...
@@ -375,16 +375,16 @@ def _make_alibi_bias(
...
@@ -375,16 +375,16 @@ def _make_alibi_bias(
# element.
# element.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
prompt
_len
+
7
)
//
8
*
8
padded_len
=
(
seq
_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
bias
=
torch
.
empty
(
1
,
# batch size
1
,
# batch size
num_heads
,
num_heads
,
prompt
_len
,
seq
_len
,
padded_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
dtype
=
dtype
,
)[:,
:,
:,
:
prompt
_len
].
copy_
(
bias
)
)[:,
:,
:,
:
seq
_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
...
...
vllm/attention/layer.py
View file @
1591c68f
...
@@ -47,3 +47,10 @@ class Attention(nn.Module):
...
@@ -47,3 +47,10 @@ class Attention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
kv_scale
)
kv_scale
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
s
+=
f
", num_heads=
{
self
.
impl
.
num_heads
}
"
# type: ignore
s
+=
f
", num_kv_heads=
{
self
.
impl
.
num_kv_heads
}
"
# type: ignore
s
+=
f
", scale=
{
self
.
impl
.
scale
}
"
# type: ignore
return
s
vllm/attention/ops/paged_attn.py
View file @
1591c68f
...
@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
...
@@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@
dataclass
@
dataclass
class
PagedAttentionMetadata
:
class
PagedAttentionMetadata
:
"""Metadata for PagedAttention."""
"""Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# sequence.
# tokens. When it is for decoding, it includes a new token.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length in the batch.
# Maximum context length in the batch.
max_seq_len
:
Optional
[
int
]
max_context_len
:
Optional
[
int
]
# (batch_size, max_blocks_per_seq).
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
...
@@ -85,8 +84,8 @@ class PagedAttention:
...
@@ -85,8 +84,8 @@ class PagedAttention:
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context
_lens
:
torch
.
Tensor
,
seq
_lens
:
torch
.
Tensor
,
max_
context
_len
:
int
,
max_
seq
_len
:
int
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
scale
:
float
,
scale
:
float
,
...
@@ -97,7 +96,7 @@ class PagedAttention:
...
@@ -97,7 +96,7 @@ class PagedAttention:
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
((
max_
context
_len
+
_PARTITION_SIZE
-
1
)
//
max_num_partitions
=
((
max_
seq
_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
_PARTITION_SIZE
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
...
@@ -106,7 +105,7 @@ class PagedAttention:
...
@@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize.
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
(
max_
context
_len
<=
8192
use_v1
=
(
max_
seq
_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
...
@@ -118,9 +117,9 @@ class PagedAttention:
...
@@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
kv_scale
,
...
@@ -150,9 +149,9 @@ class PagedAttention:
...
@@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads
,
num_kv_heads
,
scale
,
scale
,
block_tables
,
block_tables
,
context
_lens
,
seq
_lens
,
block_size
,
block_size
,
max_
context
_len
,
max_
seq
_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
kv_scale
,
...
@@ -168,10 +167,11 @@ class PagedAttention:
...
@@ -168,10 +167,11 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
subquery_start_loc
:
torch
.
Tensor
,
subquery_start_loc
:
torch
.
Tensor
,
prompt
_lens_tensor
:
torch
.
Tensor
,
seq
_lens_tensor
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_
sub
query_len
:
int
,
max_query_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
Optional
[
int
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
context_attention_fwd
(
context_attention_fwd
(
...
@@ -184,10 +184,11 @@ class PagedAttention:
...
@@ -184,10 +184,11 @@ class PagedAttention:
block_tables
,
block_tables
,
# subquery_start_loc is (batch_size + 1,)
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc
[:
-
1
],
subquery_start_loc
[:
-
1
],
prompt
_lens_tensor
,
seq
_lens_tensor
,
context_lens
,
context_lens
,
max_
sub
query_len
,
max_query_len
,
alibi_slopes
,
alibi_slopes
,
sliding_window
,
)
)
return
output
return
output
...
...
vllm/attention/ops/prefix_prefill.py
View file @
1591c68f
...
@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SLIDING_WINDOW
:
tl
.
constexpr
,
):
):
cur_batch
=
tl
.
program_id
(
0
)
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_head
=
tl
.
program_id
(
1
)
...
@@ -62,42 +63,53 @@ if triton.__version__ >= "2.1.0":
...
@@ -62,42 +63,53 @@ if triton.__version__ >= "2.1.0":
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_query_len
=
cur_batch_seq_len
-
cur_batch_ctx_len
cur_batch_query_len
=
cur_batch_seq_len
-
cur_batch_ctx_len
# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc
=
BLOCK_M
*
start_m
block_start_loc
=
BLOCK_M
*
start_m
# initialize offsets
# initialize offsets
# [N]; starts at 0
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
# [D]; starts at 0
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
# [M]; starts at current position in query
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
# [M,D]
off_q
=
(
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
dim_mask
=
tl
.
where
(
dim_mask
=
tl
.
where
(
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
# [D]
q
=
tl
.
load
(
Q
+
off_q
,
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
dim_mask
[
None
,
:]
&
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_query_len
),
(
offs_m
[:,
None
]
<
cur_batch_query_len
),
other
=
0.0
)
other
=
0.0
)
# [M,D]
# # initialize pointer to m and l
# initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
# [M]
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
# [M]
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL_PADDED
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL_PADDED
],
dtype
=
tl
.
float32
)
# [M,D]
# compute query against context (no causal mask here)
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
# -- compute qk ----
bn
=
tl
.
load
(
B_Loc
+
cur_batch
*
stride_b_loc_b
+
bn
=
tl
.
load
(
B_Loc
+
cur_batch
*
stride_b_loc_b
+
((
start_n
+
offs_n
)
//
block_size
)
*
stride_b_loc_s
,
((
start_n
+
offs_n
)
//
block_size
)
*
stride_b_loc_s
,
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
mask
=
(
start_n
+
offs_n
)
<
cur_batch_ctx_len
,
other
=
0
)
other
=
0
)
# [N]
# [D,N]
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
off_k
=
(
bn
[
None
,
:]
*
stride_k_cache_bs
+
cur_kv_head
*
stride_k_cache_h
+
cur_kv_head
*
stride_k_cache_h
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
(
offs_d
[:,
None
]
//
x
)
*
stride_k_cache_d
+
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
((
start_n
+
offs_n
[
None
,
:])
%
block_size
)
*
stride_k_cache_bl
+
stride_k_cache_bl
+
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
(
offs_d
[:,
None
]
%
x
)
*
stride_k_cache_x
)
# [N,D]
off_v
=
(
off_v
=
(
bn
[:,
None
]
*
stride_v_cache_bs
+
bn
[:,
None
]
*
stride_v_cache_bs
+
cur_kv_head
*
stride_v_cache_h
+
cur_kv_head
*
stride_v_cache_h
+
...
@@ -106,23 +118,39 @@ if triton.__version__ >= "2.1.0":
...
@@ -106,23 +118,39 @@ if triton.__version__ >= "2.1.0":
k
=
tl
.
load
(
K_cache
+
off_k
,
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
dim_mask
[:,
None
]
&
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# [D,N]
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# [M,N]
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
=
tl
.
where
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
qk
,
qk
=
tl
.
where
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
qk
,
float
(
"-inf"
))
float
(
"-inf"
))
qk
*=
sm_scale
qk
*=
sm_scale
if
SLIDING_WINDOW
>
0
:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk
=
tl
.
where
((
cur_batch_ctx_len
+
offs_m
[:,
None
])
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
-
10000
)
# -- compute m_ij, p, l_ij
# -- compute m_ij, p, l_ij
m_ij
=
tl
.
max
(
qk
,
1
)
m_ij
=
tl
.
max
(
qk
,
1
)
# [M]
p
=
tl
.
exp
(
qk
-
m_ij
[:,
None
])
p
=
tl
.
exp
(
qk
-
m_ij
[:,
None
])
# [M,N]
l_ij
=
tl
.
sum
(
p
,
1
)
l_ij
=
tl
.
sum
(
p
,
1
)
# [M]
# -- update m_i and l_i
# -- update m_i and l_i
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
m_i_new
=
tl
.
maximum
(
m_i
,
m_ij
)
# [M]
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
alpha
=
tl
.
exp
(
m_i
-
m_i_new
)
# [M]
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
beta
=
tl
.
exp
(
m_ij
-
m_i_new
)
# [M]
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
l_i_new
=
alpha
*
l_i
+
beta
*
l_ij
# [M]
# -- update output accumulator --
# -- update output accumulator --
# scale p
# scale p
p_scale
=
beta
/
l_i_new
p_scale
=
beta
/
l_i_new
...
@@ -134,7 +162,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -134,7 +162,7 @@ if triton.__version__ >= "2.1.0":
v
=
tl
.
load
(
V_cache
+
off_v
,
v
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
dim_mask
[
None
,
:]
&
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
# [N,D]
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
)
acc
+=
tl
.
dot
(
p
,
v
)
...
@@ -149,8 +177,10 @@ if triton.__version__ >= "2.1.0":
...
@@ -149,8 +177,10 @@ if triton.__version__ >= "2.1.0":
k_ptrs
=
K
+
off_k
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
v_ptrs
=
V
+
off_v
# block_mask is 0 when we're already past the current query length
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_query_len
,
1
,
0
)
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_query_len
,
1
,
0
)
# compute query against itself (with causal mask)
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
# -- compute qk ----
...
@@ -163,8 +193,13 @@ if triton.__version__ >= "2.1.0":
...
@@ -163,8 +193,13 @@ if triton.__version__ >= "2.1.0":
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
*=
sm_scale
qk
*=
sm_scale
# apply causal mask
qk
=
tl
.
where
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]),
qk
,
qk
=
tl
.
where
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]),
qk
,
float
(
"-inf"
))
float
(
"-inf"
))
if
SLIDING_WINDOW
>
0
:
qk
=
tl
.
where
(
offs_m
[:,
None
]
-
(
start_n
+
offs_n
[
None
,
:])
<
SLIDING_WINDOW
,
qk
,
-
10000
)
# -- compute m_ij, p, l_ij
# -- compute m_ij, p, l_ij
m_ij
=
tl
.
max
(
qk
,
1
)
m_ij
=
tl
.
max
(
qk
,
1
)
...
@@ -636,7 +671,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -636,7 +671,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len
,
b_seq_len
,
b_ctx_len
,
b_ctx_len
,
max_input_len
,
max_input_len
,
alibi_slopes
=
None
):
alibi_slopes
=
None
,
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
cap
=
torch
.
cuda
.
get_device_capability
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
...
@@ -644,7 +680,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -644,7 +680,7 @@ if triton.__version__ >= "2.1.0":
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lq
==
Lk
and
Lk
==
Lv
# round up Lk to a power of 2 - this is required for Triton block size
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded
=
2
**
((
Lk
-
1
).
bit_length
()
)
Lk_padded
=
triton
.
next_power_of_2
(
Lk
)
sm_scale
=
1.0
/
(
Lq
**
0.5
)
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
...
@@ -749,6 +785,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -749,6 +785,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
if
sliding_window
is
not
None
else
0
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
)
)
...
...
vllm/attention/ops/triton_flash_attention.py
View file @
1591c68f
...
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
...
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
num_warps
=
4
,
num_warps
=
4
,
),
),
],
],
key
=
[
"hq"
,
"hk"
,
"
IS_CAUSAL
"
,
"
dropout_p
"
,
"
BLOCK_DMODEL
"
],
key
=
[
'
IS_CAUSAL
'
,
'
dropout_p
'
,
'
BLOCK_DMODEL
'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
attn_fwd
(
def
attn_fwd
(
...
@@ -330,8 +330,8 @@ def attn_fwd(
...
@@ -330,8 +330,8 @@ def attn_fwd(
philox_seed
,
philox_seed
,
philox_offset_base
,
philox_offset_base
,
encoded_softmax
,
encoded_softmax
,
hq
,
HQ
:
tl
.
constexpr
,
hk
,
HK
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
...
@@ -403,7 +403,7 @@ def attn_fwd(
...
@@ -403,7 +403,7 @@ def attn_fwd(
# We still need to write 0s to the result
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z *
hq
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# l_ptrs = L + off_z *
HQ
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# we subtract this
...
@@ -414,11 +414,9 @@ def attn_fwd(
...
@@ -414,11 +414,9 @@ def attn_fwd(
# TODO: Should dropout and return encoded softmax be handled here?
# TODO: Should dropout and return encoded softmax be handled here?
return
return
is_mqa
=
hq
!=
hk
# If MQA / GQA, set the K and V head offsets appropriately.
if
is_mqa
:
# noqa: SIM108
GROUP_SIZE
:
tl
.
constexpr
=
HQ
//
HK
off_h_k
=
off_h_q
%
hk
off_h_k
=
off_h_q
//
GROUP_SIZE
if
GROUP_SIZE
!=
1
else
off_h_q
else
:
off_h_k
=
off_h_q
n_extra_tokens
=
0
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
if
seqlen_k
<
BLOCK_N
:
...
@@ -471,7 +469,7 @@ def attn_fwd(
...
@@ -471,7 +469,7 @@ def attn_fwd(
bias_ptr
=
None
bias_ptr
=
None
if
ENABLE_DROPOUT
:
if
ENABLE_DROPOUT
:
batch_philox_offset
=
philox_offset_base
\
batch_philox_offset
=
philox_offset_base
\
+
(
off_z
*
hq
+
off_h_q
)
\
+
(
off_z
*
HQ
+
off_h_q
)
\
*
seqlen_q
*
seqlen_k
*
seqlen_q
*
seqlen_k
else
:
else
:
batch_philox_offset
=
0
batch_philox_offset
=
0
...
@@ -624,7 +622,7 @@ def attn_fwd(
...
@@ -624,7 +622,7 @@ def attn_fwd(
z
=
0.0
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
# write back LSE
# l_ptrs = L + off_z *
hq
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# l_ptrs = L + off_z *
HQ
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size will be -ve
...
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
...
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
philox_seed
=
philox_seed
,
philox_seed
=
philox_seed
,
philox_offset_base
=
philox_offset
,
philox_offset_base
=
philox_offset
,
encoded_softmax
=
encoded_softmax
,
encoded_softmax
=
encoded_softmax
,
hq
=
nheads_q
,
HQ
=
nheads_q
,
hk
=
nheads_k
,
HK
=
nheads_k
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
MAX_SEQLENS_Q
=
max_seqlens_q
,
MAX_SEQLENS_Q
=
max_seqlens_q
,
MAX_SEQLENS_K
=
max_seqlens_k
,
MAX_SEQLENS_K
=
max_seqlens_k
,
...
...
vllm/attention/selector.py
View file @
1591c68f
import
enum
import
enum
import
os
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Type
from
typing
import
Type
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
VLLM_ATTENTION_BACKEND
=
"VLLM_ATTENTION_BACKEND"
class
_Backend
(
enum
.
Enum
):
class
_Backend
(
enum
.
Enum
):
FLASH_ATTN
=
enum
.
auto
()
FLASH_ATTN
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
dtype
:
torch
.
dtype
)
->
Type
[
AttentionBackend
]:
def
get_attn_backend
(
dtype
:
torch
.
dtype
)
->
Type
[
AttentionBackend
]:
backend
=
_which_attn_to_use
(
dtype
)
backend
=
_which_attn_to_use
(
dtype
)
if
backend
==
_Backend
.
FLASH_ATTN
:
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using FlashAttention backend."
)
logger
.
info
(
"Using FlashAttention
-2
backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
FlashAttentionBackend
)
return
FlashAttentionBackend
return
FlashAttentionBackend
...
@@ -43,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
...
@@ -43,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
logger
.
info
(
"Using Torch SDPA backend."
)
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend. "
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
else
:
else
:
raise
ValueError
(
"Invalid attention backend."
)
raise
ValueError
(
"Invalid attention backend."
)
...
@@ -62,12 +66,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
...
@@ -62,12 +66,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
# NVIDIA GPUs.
# NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention backend for Volta and Turing "
logger
.
info
(
"Cannot use FlashAttention
-2
backend for Volta and Turing "
"GPUs."
)
"GPUs."
)
return
_Backend
.
XFORMERS
return
_Backend
.
XFORMERS
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info
(
"Cannot use FlashAttention backend for dtype other than "
logger
.
info
(
"Cannot use FlashAttention
-2
backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
"torch.float16 or torch.bfloat16."
)
return
_Backend
.
XFORMERS
return
_Backend
.
XFORMERS
...
@@ -75,11 +79,11 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
...
@@ -75,11 +79,11 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
import
flash_attn
# noqa: F401
import
flash_attn
# noqa: F401
except
ImportError
:
except
ImportError
:
logger
.
info
(
logger
.
info
(
"Cannot use FlashAttention backend because the flash_attn
package
"
"Cannot use FlashAttention
-2
backend because the flash_attn "
"is not found. Please install it for better performance."
)
"
package
is not found. Please install it for better performance."
)
return
_Backend
.
XFORMERS
return
_Backend
.
XFORMERS
backend_by_env_var
=
os
.
get
env
(
VLLM_ATTENTION_BACKEND
)
backend_by_env_var
=
env
s
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
if
backend_by_env_var
is
not
None
:
return
_Backend
[
backend_by_env_var
]
return
_Backend
[
backend_by_env_var
]
...
...
vllm/config.py
View file @
1591c68f
import
enum
import
enum
import
json
import
json
import
os
from
dataclasses
import
dataclass
,
field
,
fields
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Union
...
@@ -9,11 +8,14 @@ from packaging.version import Version
...
@@ -9,11 +8,14 @@ from packaging.version import Version
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
(
QUANTIZATION_METHODS
,
get_quantization_config
)
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
is_neuron
)
is_neuron
)
GPTQMarlinConfig
=
get_quantization_config
(
"gptq_marlin"
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -21,10 +23,6 @@ if TYPE_CHECKING:
...
@@ -21,10 +23,6 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE
=
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
_GB
=
1
<<
30
_GB
=
1
<<
30
...
@@ -33,6 +31,8 @@ class ModelConfig:
...
@@ -33,6 +31,8 @@ class ModelConfig:
Args:
Args:
model: Name or path of the huggingface model to use.
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, and "slow" will always use the slow tokenizer.
...
@@ -65,9 +65,16 @@ class ModelConfig:
...
@@ -65,9 +65,16 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -86,8 +93,10 @@ class ModelConfig:
...
@@ -86,8 +93,10 @@ class ModelConfig:
quantization_param_path
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
bool
=
False
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
5
,
max_logprobs
:
int
=
5
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
)
->
None
:
)
->
None
:
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
...
@@ -101,6 +110,11 @@ class ModelConfig:
...
@@ -101,6 +110,11 @@ class ModelConfig:
self
.
quantization_param_path
=
quantization_param_path
self
.
quantization_param_path
=
quantization_param_path
self
.
enforce_eager
=
enforce_eager
self
.
enforce_eager
=
enforce_eager
self
.
max_context_len_to_capture
=
max_context_len_to_capture
self
.
max_context_len_to_capture
=
max_context_len_to_capture
if
self
.
max_context_len_to_capture
is
not
None
:
raise
ValueError
(
"`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead."
)
self
.
max_seq_len_to_capture
=
(
max_seq_len_to_capture
or
max_context_len_to_capture
)
self
.
max_logprobs
=
max_logprobs
self
.
max_logprobs
=
max_logprobs
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
...
@@ -110,6 +124,8 @@ class ModelConfig:
...
@@ -110,6 +124,8 @@ class ModelConfig:
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
max_model_len
)
max_model_len
)
self
.
served_model_name
=
get_served_model_name
(
model
,
served_model_name
)
if
not
self
.
skip_tokenizer_init
:
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
self
.
_verify_quantization
()
self
.
_verify_quantization
()
...
@@ -138,14 +154,34 @@ class ModelConfig:
...
@@ -138,14 +154,34 @@ class ModelConfig:
is_format_marlin
=
(
quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin"
is_format_marlin
=
(
quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin"
or
quant_cfg
.
get
(
"is_marlin_format"
,
False
))
or
quant_cfg
.
get
(
"is_marlin_format"
,
False
))
# Use marlin if the GPTQ model is serialized in marlin format.
# Check which LinearMethod the GPTQ model should use.
if
quant_method
==
"gptq"
and
is_format_marlin
:
if
quant_method
==
"gptq"
:
logger
.
info
(
"The model is serialized in Marlin format. "
# If serialized in Marlin format, use MarlinLinearMethod.
# TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
if
is_format_marlin
:
logger
.
info
(
"The model is serialized in Marlin format. "
"Using Marlin kernel."
)
quant_method
=
"marlin"
if
self
.
quantization
==
"gptq"
:
self
.
quantization
=
quant_method
# If convertible to Marlin format, use GPTQMarlinLinearMethod
# unless the user explicitly specified GPTQLinearMethod.
elif
GPTQMarlinConfig
.
is_marlin_compatible
(
quant_cfg
):
if
self
.
quantization
==
"gptq"
:
logger
.
warning
(
"The model is convertible to Marlin format, but "
"you specified quantization=gptq. Use "
"quantization=marlin for faster inference."
)
else
:
logger
.
info
(
"The model is convertible to Marlin format. "
"Using Marlin kernel."
)
"Using Marlin kernel."
)
quant_method
=
"marlin"
quant_method
=
"
gptq_
marlin"
if
self
.
quantization
==
"
gptq
"
:
if
self
.
quantization
==
"
marlin
"
:
self
.
quantization
=
quant_method
self
.
quantization
=
quant_method
# Verify quantization configurations.
if
self
.
quantization
is
None
:
if
self
.
quantization
is
None
:
self
.
quantization
=
quant_method
self
.
quantization
=
quant_method
elif
self
.
quantization
!=
quant_method
:
elif
self
.
quantization
!=
quant_method
:
...
@@ -165,17 +201,17 @@ class ModelConfig:
...
@@ -165,17 +201,17 @@ class ModelConfig:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
self
.
quantization
!=
"
marlin"
:
if
(
self
.
quantization
not
in
[
"marlin"
,
"gptq_
marlin"
])
:
logger
.
warning
(
logger
.
warning
(
f
"
{
self
.
quantization
}
quantization is not fully "
"%s
quantization is not fully "
"optimized yet. The speed can be slower than "
"optimized yet. The speed can be slower than "
"non-quantized models."
)
"non-quantized models."
,
self
.
quantization
)
def
_verify_cuda_graph
(
self
)
->
None
:
def
_verify_cuda_graph
(
self
)
->
None
:
if
self
.
max_
context
_len_to_capture
is
None
:
if
self
.
max_
seq
_len_to_capture
is
None
:
self
.
max_
context
_len_to_capture
=
self
.
max_model_len
self
.
max_
seq
_len_to_capture
=
self
.
max_model_len
self
.
max_
context
_len_to_capture
=
min
(
self
.
max_
context
_len_to_capture
,
self
.
max_
seq
_len_to_capture
=
min
(
self
.
max_
seq
_len_to_capture
,
self
.
max_model_len
)
self
.
max_model_len
)
def
verify_with_parallel_config
(
def
verify_with_parallel_config
(
self
,
self
,
...
@@ -271,6 +307,11 @@ class ModelConfig:
...
@@ -271,6 +307,11 @@ class ModelConfig:
return
max
(
1
,
return
max
(
1
,
total_num_kv_heads
//
parallel_config
.
tensor_parallel_size
)
total_num_kv_heads
//
parallel_config
.
tensor_parallel_size
)
def
get_num_attention_heads
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
return
self
.
hf_text_config
.
num_attention_heads
//
\
parallel_config
.
tensor_parallel_size
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
total_num_hidden_layers
=
self
.
hf_text_config
.
num_hidden_layers
total_num_hidden_layers
=
self
.
hf_text_config
.
num_hidden_layers
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
...
@@ -330,7 +371,8 @@ class CacheConfig:
...
@@ -330,7 +371,8 @@ class CacheConfig:
elif
self
.
cache_dtype
==
"fp8"
:
elif
self
.
cache_dtype
==
"fp8"
:
if
not
is_hip
():
if
not
is_hip
():
nvcc_cuda_version
=
get_nvcc_cuda_version
()
nvcc_cuda_version
=
get_nvcc_cuda_version
()
if
nvcc_cuda_version
<
Version
(
"11.8"
):
if
nvcc_cuda_version
is
not
None
\
and
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
ValueError
(
raise
ValueError
(
"FP8 is not supported when cuda version is"
"FP8 is not supported when cuda version is"
"lower than 11.8."
)
"lower than 11.8."
)
...
@@ -360,7 +402,7 @@ class CacheConfig:
...
@@ -360,7 +402,7 @@ class CacheConfig:
if
cpu_memory_usage
>
0.7
*
total_cpu_memory
:
if
cpu_memory_usage
>
0.7
*
total_cpu_memory
:
raise
ValueError
(
"Too large swap space. "
+
msg
)
raise
ValueError
(
"Too large swap space. "
+
msg
)
elif
cpu_memory_usage
>
0.4
*
total_cpu_memory
:
elif
cpu_memory_usage
>
0.4
*
total_cpu_memory
:
logger
.
warning
(
"Possibly too large swap space.
"
+
msg
)
logger
.
warning
(
"Possibly too large swap space.
%s"
,
msg
)
@
dataclass
@
dataclass
...
@@ -574,8 +616,9 @@ class SchedulerConfig:
...
@@ -574,8 +616,9 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
else
:
if
enable_chunked_prefill
:
if
enable_chunked_prefill
:
# For chunked prefill, choose the well-tuned batch size.
# It is the values that have the best balance between ITL
self
.
max_num_batched_tokens
=
768
# and TTFT on A100. Note it is not optimized for throughput.
self
.
max_num_batched_tokens
=
512
else
:
else
:
# If max_model_len is too short, use 2048 as the default value
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
# for higher throughput.
...
@@ -658,6 +701,8 @@ class SpeculativeConfig:
...
@@ -658,6 +701,8 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
use_v2_block_manager
:
bool
,
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
)
->
Optional
[
"SpeculativeConfig"
]:
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
"""Create a SpeculativeConfig if possible, else return None.
...
@@ -684,6 +729,10 @@ class SpeculativeConfig:
...
@@ -684,6 +729,10 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
block manager is required with spec decode.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
Returns:
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
@@ -718,39 +767,57 @@ class SpeculativeConfig:
...
@@ -718,39 +767,57 @@ class SpeculativeConfig:
draft_code_revision
=
None
draft_code_revision
=
None
draft_quantization
=
None
draft_quantization
=
None
draft_model_config
=
ModelConfig
(
if
speculative_model
==
"[ngram]"
:
model
=
speculative_model
,
assert
(
ngram_prompt_lookup_max
is
not
None
tokenizer
=
target_model_config
.
tokenizer
,
and
ngram_prompt_lookup_max
>
0
)
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
if
ngram_prompt_lookup_min
is
None
:
trust_remote_code
=
target_model_config
.
trust_remote_code
,
ngram_prompt_lookup_min
=
0
dtype
=
target_model_config
.
dtype
,
else
:
seed
=
target_model_config
.
seed
,
assert
ngram_prompt_lookup_max
>
ngram_prompt_lookup_min
revision
=
draft_revision
,
code_revision
=
draft_code_revision
,
tokenizer_revision
=
target_model_config
.
tokenizer_revision
,
max_model_len
=
None
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_context_len_to_capture
=
target_model_config
.
max_context_len_to_capture
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
draft_model_config
.
max_model_len
,
target_model_config
.
max_model_len
,
))
draft_parallel_config
=
(
# TODO: current we still need extract vocab_size from target model
SpeculativeConfig
.
create_draft_parallel_config
(
# config, in future, we may try refactor it out, and set
target_parallel_config
))
# draft related config as None here.
draft_model_config
=
target_model_config
draft_parallel_config
=
target_parallel_config
else
:
ngram_prompt_lookup_max
=
0
ngram_prompt_lookup_min
=
0
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
dtype
=
target_model_config
.
dtype
,
seed
=
target_model_config
.
seed
,
revision
=
draft_revision
,
code_revision
=
draft_code_revision
,
tokenizer_revision
=
target_model_config
.
tokenizer_revision
,
max_model_len
=
None
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_seq_len_to_capture
=
target_model_config
.
max_seq_len_to_capture
,
max_logprobs
=
target_model_config
.
max_logprobs
,
)
draft_model_config
.
max_model_len
=
(
SpeculativeConfig
.
_maybe_override_draft_max_model_len
(
speculative_max_model_len
,
draft_model_config
.
max_model_len
,
target_model_config
.
max_model_len
,
))
draft_parallel_config
=
(
SpeculativeConfig
.
create_draft_parallel_config
(
target_parallel_config
))
return
SpeculativeConfig
(
return
SpeculativeConfig
(
draft_model_config
,
draft_model_config
,
draft_parallel_config
,
draft_parallel_config
,
num_speculative_tokens
,
num_speculative_tokens
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
)
)
@
staticmethod
@
staticmethod
...
@@ -818,6 +885,8 @@ class SpeculativeConfig:
...
@@ -818,6 +885,8 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
num_speculative_tokens
:
int
,
ngram_prompt_lookup_max
:
int
,
ngram_prompt_lookup_min
:
int
,
):
):
"""Create a SpeculativeConfig object.
"""Create a SpeculativeConfig object.
...
@@ -830,6 +899,8 @@ class SpeculativeConfig:
...
@@ -830,6 +899,8 @@ class SpeculativeConfig:
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
self
.
_verify_args
()
self
.
_verify_args
()
...
@@ -853,7 +924,10 @@ class SpeculativeConfig:
...
@@ -853,7 +924,10 @@ class SpeculativeConfig:
return
self
.
num_speculative_tokens
return
self
.
num_speculative_tokens
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
draft_model
=
self
.
draft_model_config
.
model
if
self
.
ngram_prompt_lookup_max
>
0
:
draft_model
=
"[ngram]"
else
:
draft_model
=
self
.
draft_model_config
.
model
num_spec_tokens
=
self
.
num_speculative_tokens
num_spec_tokens
=
self
.
num_speculative_tokens
return
f
"SpeculativeConfig(
{
draft_model
=
}
,
{
num_spec_tokens
=
}
)"
return
f
"SpeculativeConfig(
{
draft_model
=
}
,
{
num_spec_tokens
=
}
)"
...
@@ -862,6 +936,7 @@ class SpeculativeConfig:
...
@@ -862,6 +936,7 @@ class SpeculativeConfig:
class
LoRAConfig
:
class
LoRAConfig
:
max_lora_rank
:
int
max_lora_rank
:
int
max_loras
:
int
max_loras
:
int
fully_sharded_loras
:
bool
=
False
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
lora_dtype
:
Optional
[
torch
.
dtype
]
=
None
lora_dtype
:
Optional
[
torch
.
dtype
]
=
None
lora_extra_vocab_size
:
int
=
256
lora_extra_vocab_size
:
int
=
256
...
@@ -898,8 +973,8 @@ class LoRAConfig:
...
@@ -898,8 +973,8 @@ class LoRAConfig:
"awq"
,
"gptq"
"awq"
,
"gptq"
]:
]:
# TODO support marlin and squeezellm
# TODO support marlin and squeezellm
logger
.
warning
(
f
"
{
model_config
.
quantization
}
quantization is not "
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
"tested with LoRA yet."
)
model_config
.
quantization
)
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
if
scheduler_config
.
max_num_batched_tokens
>
65528
:
if
scheduler_config
.
max_num_batched_tokens
>
65528
:
...
@@ -1008,7 +1083,7 @@ def _get_and_verify_dtype(
...
@@ -1008,7 +1083,7 @@ def _get_and_verify_dtype(
pass
pass
else
:
else
:
# Casting between float16 and bfloat16 is allowed with a warning.
# Casting between float16 and bfloat16 is allowed with a warning.
logger
.
warning
(
f
"Casting
{
config_dtype
}
to
{
torch_dtype
}
."
)
logger
.
warning
(
"Casting
%s to %s."
,
config_dtype
,
torch_dtype
)
return
torch_dtype
return
torch_dtype
...
@@ -1051,12 +1126,12 @@ def _get_and_verify_max_len(
...
@@ -1051,12 +1126,12 @@ def _get_and_verify_max_len(
logger
.
warning
(
logger
.
warning
(
"The model's config.json does not contain any of the following "
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
"keys to determine the original maximum length of the model: "
f
"
{
possible_keys
}
. Assuming the model's maximum length is
"
"%d
. Assuming the model's maximum length is
%d."
,
possible_keys
,
f
"
{
default_max_len
}
."
)
default_max_len
)
derived_max_model_len
=
default_max_len
derived_max_model_len
=
default_max_len
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
hf_config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
:
if
rope_scaling
is
not
None
and
rope_scaling
[
"type"
]
!=
"su"
:
assert
"factor"
in
rope_scaling
assert
"factor"
in
rope_scaling
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
rope_scaling
[
"type"
]
==
"yarn"
:
if
rope_scaling
[
"type"
]
==
"yarn"
:
...
@@ -1084,6 +1159,22 @@ def _get_and_verify_max_len(
...
@@ -1084,6 +1159,22 @@ def _get_and_verify_max_len(
return
int
(
max_model_len
)
return
int
(
max_model_len
)
def
get_served_model_name
(
model
:
str
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an
empty list, the fallback is to use `self.model`.
"""
if
not
served_model_name
:
return
model
if
isinstance
(
served_model_name
,
list
):
return
served_model_name
[
0
]
return
served_model_name
@
dataclass
@
dataclass
class
DecodingConfig
:
class
DecodingConfig
:
"""Dataclass which contains the decoding strategy of the engine"""
"""Dataclass which contains the decoding strategy of the engine"""
...
...
vllm/core/block/block_table.py
View file @
1591c68f
...
@@ -40,7 +40,9 @@ class BlockTable:
...
@@ -40,7 +40,9 @@ class BlockTable:
):
):
self
.
_block_size
=
block_size
self
.
_block_size
=
block_size
self
.
_allocator
=
block_allocator
self
.
_allocator
=
block_allocator
self
.
_blocks
:
Optional
[
List
[
Block
]]
=
_blocks
if
_blocks
is
None
:
_blocks
=
[]
self
.
_blocks
:
List
[
Block
]
=
_blocks
# Use helper method instead of directly calculating, as blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
# may not be allocated.
...
@@ -104,7 +106,7 @@ class BlockTable:
...
@@ -104,7 +106,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
token_ids (List[int]): The sequence of token IDs to be appended.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
assert
self
.
_blocks
is
not
None
assert
len
(
self
.
_blocks
)
>
0
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
num_lookahead_slots
)
...
@@ -141,6 +143,7 @@ class BlockTable:
...
@@ -141,6 +143,7 @@ class BlockTable:
blocks_to_allocate
=
cdiv
(
slots_to_allocate
,
self
.
_block_size
)
blocks_to_allocate
=
cdiv
(
slots_to_allocate
,
self
.
_block_size
)
for
_
in
range
(
blocks_to_allocate
):
for
_
in
range
(
blocks_to_allocate
):
assert
len
(
self
.
_blocks
)
>
0
self
.
_blocks
.
append
(
self
.
_blocks
.
append
(
self
.
_allocator
.
allocate_mutable
(
prev_block
=
self
.
_blocks
[
-
1
],
self
.
_allocator
.
allocate_mutable
(
prev_block
=
self
.
_blocks
[
-
1
],
device
=
device
))
device
=
device
))
...
@@ -159,6 +162,7 @@ class BlockTable:
...
@@ -159,6 +162,7 @@ class BlockTable:
the current instance.
the current instance.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
assert
len
(
self
.
_blocks
)
>
0
forked_blocks
=
self
.
_allocator
.
fork
(
self
.
_blocks
[
-
1
])
forked_blocks
=
self
.
_allocator
.
fork
(
self
.
_blocks
[
-
1
])
return
BlockTable
(
return
BlockTable
(
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
...
@@ -177,10 +181,10 @@ class BlockTable:
...
@@ -177,10 +181,10 @@ class BlockTable:
assert
self
.
_is_allocated
assert
self
.
_is_allocated
for
block
in
self
.
_blocks
:
for
block
in
self
.
_blocks
:
self
.
_allocator
.
free
(
block
)
self
.
_allocator
.
free
(
block
)
self
.
_blocks
=
None
self
.
_blocks
=
[]
@
property
@
property
def
physical_block_ids
(
self
)
->
List
[
int
]:
def
physical_block_ids
(
self
)
->
List
[
Optional
[
int
]
]
:
"""Returns a list of physical block indices for the blocks in the
"""Returns a list of physical block indices for the blocks in the
BlockTable.
BlockTable.
...
@@ -235,7 +239,7 @@ class BlockTable:
...
@@ -235,7 +239,7 @@ class BlockTable:
def
_get_all_token_ids
(
self
)
->
List
[
int
]:
def
_get_all_token_ids
(
self
)
->
List
[
int
]:
# NOTE: This function is O(seq_len); use sparingly.
# NOTE: This function is O(seq_len); use sparingly.
token_ids
=
[]
token_ids
:
List
[
int
]
=
[]
if
not
self
.
_is_allocated
:
if
not
self
.
_is_allocated
:
return
token_ids
return
token_ids
...
@@ -247,7 +251,7 @@ class BlockTable:
...
@@ -247,7 +251,7 @@ class BlockTable:
@
property
@
property
def
_is_allocated
(
self
)
->
bool
:
def
_is_allocated
(
self
)
->
bool
:
return
self
.
_blocks
is
not
None
return
len
(
self
.
_blocks
)
>
0
@
property
@
property
def
_num_empty_slots
(
self
)
->
int
:
def
_num_empty_slots
(
self
)
->
int
:
...
...
vllm/core/block/common.py
View file @
1591c68f
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
...
@@ -7,7 +7,19 @@ BlockId = int
...
@@ -7,7 +7,19 @@ BlockId = int
RefCount
=
int
RefCount
=
int
class
RefCounter
:
class
RefCounterProtocol
(
Protocol
):
def
incr
(
self
,
block_id
:
BlockId
)
->
RefCount
:
raise
NotImplementedError
def
decr
(
self
,
block_id
:
BlockId
)
->
RefCount
:
raise
NotImplementedError
def
get
(
self
,
block_id
:
BlockId
)
->
RefCount
:
raise
NotImplementedError
class
RefCounter
(
RefCounterProtocol
):
"""A class for managing reference counts for a set of block indices.
"""A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their
The RefCounter class maintains a dictionary that maps block indices to their
...
@@ -54,7 +66,7 @@ class RefCounter:
...
@@ -54,7 +66,7 @@ class RefCounter:
return
ReadOnlyRefCounter
(
self
)
return
ReadOnlyRefCounter
(
self
)
class
ReadOnlyRefCounter
:
class
ReadOnlyRefCounter
(
RefCounterProtocol
)
:
"""A read-only view of the RefCounter class.
"""A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the
The ReadOnlyRefCounter class provides a read-only interface to access the
...
@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
...
@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
def
__init__
(
def
__init__
(
self
,
self
,
refcounter
:
RefCounter
,
refcounter
:
RefCounter
Protocol
,
allocator
:
BlockAllocator
,
allocator
:
BlockAllocator
,
):
):
self
.
_copy_on_writes
:
Dict
[
BlockId
,
List
[
BlockId
]]
=
defaultdict
(
list
)
self
.
_copy_on_writes
:
Dict
[
BlockId
,
List
[
BlockId
]]
=
defaultdict
(
list
)
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
1591c68f
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
FrozenSet
,
List
,
Optional
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
BlockId
,
DeviceAwareBlockAllocator
)
DeviceAwareBlockAllocator
)
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.block.prefix_caching_block
import
PrefixCachingBlockAllocator
from
vllm.core.block.prefix_caching_block
import
PrefixCachingBlockAllocator
...
@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids
=
block_ids
[
num_gpu_blocks
:]
cpu_block_ids
=
block_ids
[
num_gpu_blocks
:]
if
allocator_type
==
"naive"
:
if
allocator_type
==
"naive"
:
gpu_allocator
=
NaiveBlockAllocator
(
gpu_allocator
:
BlockAllocator
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
create_block
=
NaiveBlock
,
# type: ignore
num_blocks
=
num_gpu_blocks
,
num_blocks
=
num_gpu_blocks
,
block_size
=
block_size
,
block_size
=
block_size
,
block_ids
=
gpu_block_ids
,
block_ids
=
gpu_block_ids
,
)
)
cpu_allocator
=
NaiveBlockAllocator
(
cpu_allocator
:
BlockAllocator
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
create_block
=
NaiveBlock
,
# type: ignore
num_blocks
=
num_cpu_blocks
,
num_blocks
=
num_cpu_blocks
,
block_size
=
block_size
,
block_size
=
block_size
,
block_ids
=
cpu_block_ids
,
block_ids
=
cpu_block_ids
,
...
@@ -105,7 +105,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -105,7 +105,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device
.
GPU
:
gpu_block_allocator
,
Device
.
GPU
:
gpu_block_allocator
,
}
}
self
.
_block_ids_to_allocator
=
{}
self
.
_block_ids_to_allocator
:
Dict
[
int
,
BlockAllocator
]
=
{}
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
block_id
in
allocator
.
all_block_ids
:
for
block_id
in
allocator
.
all_block_ids
:
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
...
@@ -149,7 +149,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -149,7 +149,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
Args:
block (Block): The block to be freed.
block (Block): The block to be freed.
"""
"""
allocator
=
self
.
_block_ids_to_allocator
[
block
.
block_id
]
block_id
=
block
.
block_id
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
return
allocator
.
free
(
block
)
return
allocator
.
free
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
...
@@ -163,7 +165,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -163,7 +165,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
original sequence.
"""
"""
allocator
=
self
.
_block_ids_to_allocator
[
last_block
.
block_id
]
block_id
=
last_block
.
block_id
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
return
allocator
.
fork
(
last_block
)
return
allocator
.
fork
(
last_block
)
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
...
@@ -171,13 +175,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -171,13 +175,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
Args:
device (Device): The device for which to query the number of free
device (Device): The device for which to query the number of free
blocks.
blocks.
AssertionError is raised if None is passed.
Returns:
Returns:
int: The number of free blocks available on the specified device.
int: The number of free blocks available on the specified device.
"""
"""
return
self
.
_allocators
[
device
].
get_num_free_blocks
()
return
self
.
_allocators
[
device
].
get_num_free_blocks
()
def
get_num_total_blocks
(
self
,
device
:
Device
)
->
int
:
return
self
.
_allocators
[
device
].
get_num_total_blocks
()
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
source to destination block IDs.
...
@@ -190,10 +197,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -190,10 +197,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device
=
Device
.
GPU
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
clear_copy_on_writes
()
return
self
.
_allocators
[
device
].
clear_copy_on_writes
()
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_accessed
(
block_ids
,
now
)
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
()
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
...
@@ -202,5 +217,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -202,5 +217,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return
self
.
_allocators
[
device
].
get_common_computed_block_ids
(
return
self
.
_allocators
[
device
].
get_common_computed_block_ids
(
seq_block_ids
)
seq_block_ids
)
def
all_block_ids
(
self
)
->
frozenset
[
int
]:
@
property
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
return
frozenset
(
self
.
_block_ids_to_allocator
.
keys
())
return
frozenset
(
self
.
_block_ids_to_allocator
.
keys
())
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
raise
NotImplementedError
vllm/core/block/interfaces.py
View file @
1591c68f
...
@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
...
@@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
from
vllm.utils
import
Device
from
vllm.utils
import
Device
BlockId
=
int
class
Block
(
ABC
):
class
Block
(
ABC
):
...
@@ -15,6 +17,12 @@ class Block(ABC):
...
@@ -15,6 +17,12 @@ class Block(ABC):
def
block_id
(
self
)
->
Optional
[
int
]:
def
block_id
(
self
)
->
Optional
[
int
]:
pass
pass
@
block_id
.
setter
@
abstractmethod
def
block_id
(
self
,
value
:
Optional
[
int
])
->
None
:
"""NOTE: Do not use this API outside Block."""
self
.
_block_id
=
value
@
property
@
property
@
abstractmethod
@
abstractmethod
def
token_ids
(
self
)
->
List
[
int
]:
def
token_ids
(
self
)
->
List
[
int
]:
...
@@ -35,6 +43,27 @@ class Block(ABC):
...
@@ -35,6 +43,27 @@ class Block(ABC):
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
pass
pass
@
property
@
abstractmethod
def
computed
(
self
)
->
bool
:
raise
NotImplementedError
@
computed
.
setter
@
abstractmethod
def
computed
(
self
,
value
)
->
bool
:
"""Should be only used by PrefixCacingAllocator"""
raise
NotImplementedError
@
property
@
abstractmethod
def
last_accessed
(
self
)
->
float
:
raise
NotImplementedError
@
last_accessed
.
setter
@
abstractmethod
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
raise
NotImplementedError
class
Factory
(
Protocol
):
class
Factory
(
Protocol
):
@
abstractmethod
@
abstractmethod
...
@@ -48,6 +77,17 @@ class Block(ABC):
...
@@ -48,6 +77,17 @@ class Block(ABC):
)
->
"Block"
:
)
->
"Block"
:
pass
pass
@
property
@
abstractmethod
def
content_hash
(
self
)
->
Optional
[
int
]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return
None
class
BlockAllocator
(
ABC
):
class
BlockAllocator
(
ABC
):
...
@@ -57,7 +97,7 @@ class BlockAllocator(ABC):
...
@@ -57,7 +97,7 @@ class BlockAllocator(ABC):
@
abstractmethod
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
]
,
device
:
Device
)
->
Block
:
token_ids
:
List
[
int
])
->
Block
:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -69,7 +109,11 @@ class BlockAllocator(ABC):
...
@@ -69,7 +109,11 @@ class BlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
def
get_num_total_blocks
(
self
)
->
int
:
pass
@
abstractmethod
def
get_num_free_blocks
(
self
)
->
int
:
pass
pass
@
property
@
property
...
@@ -82,7 +126,12 @@ class BlockAllocator(ABC):
...
@@ -82,7 +126,12 @@ class BlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
pass
@
abstractmethod
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -90,14 +139,25 @@ class BlockAllocator(ABC):
...
@@ -90,14 +139,25 @@ class BlockAllocator(ABC):
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
pass
@
abstractmethod
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
"BlockId"
]:
"""NOTE: This should not be used besides Block"""
pass
@
abstractmethod
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
"""NOTE: This should not be used besides Block"""
pass
class
NoFreeBlocksError
(
ValueError
):
class
NoFreeBlocksError
(
ValueError
):
pass
pass
class
DeviceAwareBlockAllocator
(
BlockAllocator
):
class
DeviceAwareBlockAllocator
(
ABC
):
@
abstractmethod
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -108,3 +168,38 @@ class DeviceAwareBlockAllocator(BlockAllocator):
...
@@ -108,3 +168,38 @@ class DeviceAwareBlockAllocator(BlockAllocator):
@
abstractmethod
@
abstractmethod
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
pass
pass
@
abstractmethod
def
get_num_total_blocks
(
self
,
device
:
Device
)
->
int
:
pass
@
abstractmethod
def
free
(
self
,
block
:
Block
)
->
None
:
pass
@
abstractmethod
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
pass
@
property
@
abstractmethod
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
pass
@
abstractmethod
def
clear_copy_on_writes
(
self
)
->
Dict
[
int
,
List
[
int
]]:
pass
@
abstractmethod
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
pass
@
abstractmethod
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
vllm/core/block/naive_block.py
View file @
1591c68f
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Set
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
BlockId
=
int
Refcount
=
int
Refcount
=
int
...
@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
allocator
=
self
,
allocator
=
self
,
)
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable
(
self
,
token_ids
:
List
[
int
])
->
Block
:
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new immutable block with the given token IDs, linked to
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
the previous block.
...
@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Returns:
Block: The newly allocated immutable block.
Block: The newly allocated immutable block.
"""
"""
assert
device
is
None
block
=
self
.
allocate_mutable
(
prev_block
=
prev_block
)
block
=
self
.
allocate_mutable
(
prev_block
=
prev_block
)
block
.
append_token_ids
(
token_ids
)
block
.
append_token_ids
(
token_ids
)
return
block
return
block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new mutable block, linked to the previous block.
"""Allocates a new mutable block, linked to the previous block.
Args:
Args:
...
@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
Returns:
Returns:
Block: The newly allocated mutable block.
Block: The newly allocated mutable block.
"""
"""
assert
device
is
None
block_id
=
self
.
_allocate_new_block_id
()
block_id
=
self
.
_allocate_new_block_id
()
return
self
.
_create_block
(
return
self
.
_create_block
(
prev_block
=
prev_block
,
prev_block
=
prev_block
,
...
@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
)
)
def
free
(
self
,
block
:
Block
)
->
None
:
def
free
(
self
,
block
:
Block
)
->
None
:
assert
block
.
block_id
is
not
None
self
.
_free_block_id
(
block
.
block_id
)
self
.
_free_block_id
(
block
.
block_id
)
# Mark the block as having no allocation.
# Mark the block as having no allocation.
...
@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
for
block
in
source_blocks
:
for
block
in
source_blocks
:
# Increment refcount for each block.
# Increment refcount for each block.
assert
block
.
block_id
is
not
None
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
assert
refcount
!=
1
,
"can't fork free'd block"
assert
refcount
!=
1
,
"can't fork free'd block"
...
@@ -129,6 +136,9 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -129,6 +136,9 @@ class NaiveBlockAllocator(BlockAllocator):
def
get_num_free_blocks
(
self
)
->
int
:
def
get_num_free_blocks
(
self
)
->
int
:
return
len
(
self
.
_free_block_indices
)
return
len
(
self
.
_free_block_indices
)
def
get_num_total_blocks
(
self
)
->
int
:
return
len
(
self
.
_all_block_indices
)
def
_allocate_new_block_id
(
self
)
->
BlockId
:
def
_allocate_new_block_id
(
self
)
->
BlockId
:
if
not
self
.
_free_block_indices
:
if
not
self
.
_free_block_indices
:
raise
BlockAllocator
.
NoFreeBlocksError
()
raise
BlockAllocator
.
NoFreeBlocksError
()
...
@@ -148,7 +158,7 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -148,7 +158,7 @@ class NaiveBlockAllocator(BlockAllocator):
return
self
.
_refcounter
return
self
.
_refcounter
@
property
@
property
def
all_block_ids
(
self
):
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]
:
return
self
.
_all_block_indices
return
self
.
_all_block_indices
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
...
@@ -174,7 +184,16 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -174,7 +184,16 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
return
self
.
_cow_tracker
.
clear_cows
()
return
self
.
_cow_tracker
.
clear_cows
()
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
"""Mark blocks as accessed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as computed, used in prefix caching.
"""Mark blocks as computed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
Since the naive allocator does not implement prefix caching, we do
...
@@ -191,6 +210,9 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -191,6 +210,9 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
return
[]
return
[]
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
class
NaiveBlock
(
Block
):
class
NaiveBlock
(
Block
):
"""An implementation of the Block class that does not support prefix
"""An implementation of the Block class that does not support prefix
...
@@ -215,13 +237,13 @@ class NaiveBlock(Block):
...
@@ -215,13 +237,13 @@ class NaiveBlock(Block):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
prev_block
:
Block
,
prev_block
:
Optional
[
Block
]
,
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
block_size
:
int
,
block_size
:
int
,
allocator
:
BlockAllocator
,
allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
block_id
:
Optional
[
int
]
=
None
,
_cow_target
:
Optional
[
Block
]
=
None
):
_cow_target
:
Optional
[
Block
]
=
None
):
self
.
_token_ids
=
[]
self
.
_token_ids
:
List
[
int
]
=
[]
self
.
_block_size
=
block_size
self
.
_block_size
=
block_size
self
.
_prev_block
=
prev_block
self
.
_prev_block
=
prev_block
self
.
_block_id
=
block_id
self
.
_block_id
=
block_id
...
@@ -247,6 +269,22 @@ class NaiveBlock(Block):
...
@@ -247,6 +269,22 @@ class NaiveBlock(Block):
assert
self
.
num_empty_slots
>=
len
(
token_ids
)
assert
self
.
num_empty_slots
>=
len
(
token_ids
)
self
.
_token_ids
.
extend
(
token_ids
)
self
.
_token_ids
.
extend
(
token_ids
)
@
property
def
computed
(
self
)
->
bool
:
raise
NotImplementedError
@
computed
.
setter
def
computed
(
self
,
value
)
->
None
:
raise
NotImplementedError
@
property
def
last_accessed
(
self
)
->
float
:
raise
NotImplementedError
@
last_accessed
.
setter
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
raise
NotImplementedError
@
property
@
property
def
block_id
(
self
)
->
Optional
[
int
]:
def
block_id
(
self
)
->
Optional
[
int
]:
return
self
.
_block_id
return
self
.
_block_id
...
@@ -267,9 +305,14 @@ class NaiveBlock(Block):
...
@@ -267,9 +305,14 @@ class NaiveBlock(Block):
def
token_ids
(
self
)
->
List
[
int
]:
def
token_ids
(
self
)
->
List
[
int
]:
return
self
.
_token_ids
return
self
.
_token_ids
@
property
def
block_size
(
self
)
->
int
:
def
block_size
(
self
)
->
int
:
return
self
.
_block_size
return
self
.
_block_size
@
property
@
property
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
return
self
.
_prev_block
return
self
.
_prev_block
@
property
def
content_hash
(
self
)
->
Optional
[
int
]:
return
None
vllm/core/block/prefix_caching_block.py
View file @
1591c68f
"""Token blocks."""
"""Token blocks."""
from
itertools
import
takewhile
from
itertools
import
takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.evictor_v2
import
EvictionPolicy
,
Evictor
,
make_evictor
PrefixHash
=
int
PrefixHash
=
int
BlockId
=
int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME
=
-
1
class
PrefixCachingBlockAllocator
(
BlockAllocator
):
class
PrefixCachingBlockAllocator
(
BlockAllocator
):
...
@@ -27,26 +32,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -27,26 +32,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
from 0 to num_blocks - 1.
from 0 to num_blocks - 1.
"""
"""
# TODO last access time / evictor integration
def
__init__
(
def
__init__
(
self
,
self
,
num_blocks
:
int
,
num_blocks
:
int
,
block_size
:
int
,
block_size
:
int
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
eviction_policy
:
EvictionPolicy
=
EvictionPolicy
.
LRU
,
):
):
# A mapping of prefix hash to block index. All blocks which have a
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
# prefix hash will be in this dict, even if they have refcount 0.
self
.
_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
self
.
_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
# A mapping of prefix hash to block index. All blocks which have a
# A mapping of blockId to Block to track those cached blocks
# prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
self
.
_blocks
:
Dict
[
BlockId
,
Block
]
=
{}
# of self._cached_blocks.
self
.
_unused_cached_blocks
:
Dict
[
PrefixHash
,
BlockId
]
=
{}
# An allocator for blocks that do not have prefix hashes.
# An allocator for blocks that do not have prefix hashes.
self
.
_hashless_allocator
=
NaiveBlockAllocator
(
self
.
_hashless_allocator
=
NaiveBlockAllocator
(
create_block
=
self
.
_create_block
,
create_block
=
self
.
_create_block
,
# type: ignore
num_blocks
=
num_blocks
,
num_blocks
=
num_blocks
,
block_size
=
block_size
,
block_size
=
block_size
,
block_ids
=
block_ids
,
block_ids
=
block_ids
,
...
@@ -54,6 +56,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -54,6 +56,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self
.
_block_size
=
block_size
self
.
_block_size
=
block_size
# Evitor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self
.
evictor
:
Evictor
=
make_evictor
(
eviction_policy
)
# We share the refcounter between allocators. This allows us to promote
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
# blocks originally allocated in the hashless allocator to immutable
# blocks.
# blocks.
...
@@ -72,6 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -72,6 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size
:
int
,
block_size
:
int
,
allocator
:
BlockAllocator
,
allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
block_id
:
Optional
[
int
]
=
None
,
computed
:
bool
=
False
,
)
->
Block
:
)
->
Block
:
# Bind block to self.
# Bind block to self.
allocator
=
self
allocator
=
self
...
@@ -82,10 +89,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -82,10 +89,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size
=
block_size
,
block_size
=
block_size
,
block_id
=
block_id
,
block_id
=
block_id
,
prefix_caching_allocator
=
allocator
,
prefix_caching_allocator
=
allocator
,
computed
=
computed
,
)
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable
(
self
,
token_ids
:
List
[
int
])
->
Block
:
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates an immutable block with the given token IDs, reusing cached
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
blocks if possible.
...
@@ -96,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -96,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Returns:
Returns:
Block: The allocated immutable block.
Block: The allocated immutable block.
"""
"""
assert
device
is
None
assert_prefix_caching_block_or_none
(
prev_block
)
assert_prefix_caching_block_or_none
(
prev_block
)
block
=
self
.
_create_block
(
block
=
self
.
_create_block
(
...
@@ -109,65 +120,95 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -109,65 +120,95 @@ class PrefixCachingBlockAllocator(BlockAllocator):
cached_block_id
=
self
.
_cached_blocks
.
get
(
block
.
content_hash
,
None
)
cached_block_id
=
self
.
_cached_blocks
.
get
(
block
.
content_hash
,
None
)
if
cached_block_id
is
not
None
:
if
cached_block_id
is
not
None
:
block
.
block_id
=
cached_block_id
block
.
block_id
=
cached_block_id
self
.
_incr_refcount_cached_block
(
block
.
content_hash
,
self
.
_incr_refcount_cached_block
(
block
,
block
.
block_id
)
block
.
block_id
)
return
block
return
block
block
=
self
.
allocate_mutable
(
prev_block
)
block
=
self
.
allocate_mutable
(
prev_block
)
block
.
append_token_ids
(
token_ids
)
block
.
append_token_ids
(
token_ids
)
assert
block
.
content_hash
is
not
None
assert
block
.
content_hash
is
not
None
# TODO computed bit
return
block
return
block
def
allocate_mutable
(
self
,
prev_block
:
Block
)
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a mutable block. If there are no free blocks, this will
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
evict unused cached blocks.
Args:
Args:
prev_block (Block): The previous block in the sequence.
prev_block (Block): The previous block in the sequence.
None is not allowed unlike it is super class.
Returns:
Returns:
Block: The allocated mutable block.
Block: The allocated mutable block.
"""
"""
assert
device
is
None
assert_prefix_caching_block_or_none
(
prev_block
)
assert_prefix_caching_block_or_none
(
prev_block
)
try
:
try
:
return
self
.
_hashless_allocator
.
allocate_mutable
(
block
=
self
.
_hashless_allocator
.
allocate_mutable
(
prev_block
=
prev_block
)
prev_block
=
prev_block
)
assert
block
.
block_id
not
in
self
.
_blocks
assert
block
.
block_id
is
not
None
self
.
_blocks
[
block
.
block_id
]
=
block
return
block
except
BlockAllocator
.
NoFreeBlocksError
:
except
BlockAllocator
.
NoFreeBlocksError
:
# We must check the unused cached blocks before raising OOM.
# We must check the unused cached blocks before raising OOM.
pass
pass
if
self
.
_unused_cached_blocks
:
# If the evictor has blocks available for eviction, evict a block
# TODO policy for selecting block to remove
# and return it.
content_hash_to_evict
=
next
(
iter
(
self
.
_unused_cached_blocks
))
if
self
.
evictor
.
num_blocks
>
0
:
block_id
,
content_hash_to_evict
=
self
.
evictor
.
evict
()
# Here we may have scenario that several blocks have
# the same content hash, but due to the latter coming block
# is coming from mutable to immutable path, their physical
# block is added into evictor.
# However in this case, we shall not pop the _cached_blocks,
# as the same content is still used by others, which means
# we need to check ref before decide to pop the list.
# Clear content hash mapping; the block will be overwritten.
_block_id
=
self
.
_cached_blocks
[
content_hash_to_evict
]
del
self
.
_cached_blocks
[
content_hash_to_evict
]
refcount
=
self
.
_refcounter
.
get
(
_block_id
)
if
refcount
==
1
:
self
.
_cached_blocks
.
pop
(
content_hash_to_evict
)
assert
_block_id
==
block_id
block_id
=
self
.
_unused_cached_blocks
.
pop
(
content_hash_to_evict
)
self
.
_refcounter
.
incr
(
block_id
)
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
assert
ref
co
u
nt
==
1
# the block comes from evictor already
cont
ain computed result
block
=
self
.
_create_block
(
block
=
self
.
_create_block
(
prev_block
=
prev_block
,
prev_block
=
prev_block
,
token_ids
=
[],
token_ids
=
[],
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
allocator
=
self
,
allocator
=
self
,
block_id
=
block_id
,
block_id
=
block_id
,
computed
=
True
,
)
)
assert
block
.
content_hash
is
None
assert
block
.
content_hash
is
None
assert
block
.
block_id
not
in
self
.
_blocks
assert
block
.
block_id
is
not
None
self
.
_blocks
[
block
.
block_id
]
=
block
return
block
return
block
# No block available in hashless allocator, nor in unused cache blocks.
# No block available in hashless allocator, nor in unused cache blocks.
raise
BlockAllocator
.
NoFreeBlocksError
()
raise
BlockAllocator
.
NoFreeBlocksError
()
def
_incr_refcount_cached_block
(
self
,
content_hash
:
int
,
def
_incr_refcount_cached_block
(
self
,
block
:
Block
,
block_id
:
BlockId
)
->
None
:
block_id
:
BlockId
)
->
None
:
# since block is already computed, mark it
block
.
computed
=
True
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
refcount
=
self
.
_refcounter
.
incr
(
block_id
)
if
refcount
==
1
:
if
refcount
==
1
:
assert
content_hash
in
self
.
_unused_cached_blocks
# if block get referred, then it shall not be in evictor
del
self
.
_unused_cached_blocks
[
content_hash
]
# and put it into _blocks for tracking
if
block_id
in
self
.
evictor
:
self
.
evictor
.
remove
(
block_id
)
self
.
_blocks
[
block_id
]
=
block
def
free
(
self
,
block
:
Block
)
->
None
:
def
free
(
self
,
block
:
Block
)
->
None
:
"""Decrement the refcount of the block. If the decremented refcount is
"""Decrement the refcount of the block. If the decremented refcount is
...
@@ -180,6 +221,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -180,6 +221,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
is
not
None
),
"freeing unallocated block is undefined"
is
not
None
),
"freeing unallocated block is undefined"
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
block
.
block_id
=
None
block
.
block_id
=
None
def
_free_block_id_for_block
(
self
,
block_id
:
BlockId
,
def
_free_block_id_for_block
(
self
,
block_id
:
BlockId
,
...
@@ -187,15 +229,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -187,15 +229,23 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert
isinstance
(
block
,
PrefixCachingBlock
)
assert
isinstance
(
block
,
PrefixCachingBlock
)
if
block
.
content_hash
is
None
:
if
block
.
content_hash
is
None
:
refcount
=
self
.
_refcounter
.
get
(
block_id
)
# We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1
if
refcount
<=
1
:
assert
block
.
block_id
is
not
None
del
self
.
_blocks
[
block
.
block_id
]
return
self
.
_hashless_allocator
.
free
(
block
)
return
self
.
_hashless_allocator
.
free
(
block
)
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
# If no longer used, add the block to the
unused cached blocks
.
# If no longer used, add the block to the
evictor
.
if
refcount
==
0
:
if
refcount
==
0
:
assert
block
.
content_hash
not
in
self
.
_unused_cached_blocks
assert
block
.
content_hash
in
self
.
_cached_blocks
assert
block
.
content_hash
in
self
.
_cached_blocks
self
.
_unused_cached_blocks
[
block
.
content_hash
]
=
block_id
assert
block
.
block_id
is
not
None
del
self
.
_blocks
[
block
.
block_id
]
self
.
evictor
.
add
(
block
.
block_id
,
block
.
content_hash
,
block
.
num_tokens_total
,
block
.
last_accessed
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
"""Creates a new sequence of blocks that shares the same underlying
...
@@ -228,18 +278,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -228,18 +278,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return
forked_blocks
return
forked_blocks
def
get_num_free_blocks
(
self
)
->
int
:
def
get_num_free_blocks
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
int
:
assert
device
is
None
# The number of free blocks is the number of hashless free blocks
# The number of free blocks is the number of hashless free blocks
# plus the number of hashful blocks that are unused.
# plus the number of blocks evictor could free from its list.
return
self
.
_hashless_allocator
.
get_num_free_blocks
()
+
len
(
return
self
.
_hashless_allocator
.
get_num_free_blocks
(
self
.
_unused_cached_blocks
)
)
+
self
.
evictor
.
num_blocks
def
get_num_total_blocks
(
self
)
->
int
:
return
self
.
_hashless_allocator
.
get_num_total_blocks
()
@
property
@
property
def
all_block_ids
(
self
)
->
f
rozen
s
et
[
int
]:
def
all_block_ids
(
self
)
->
F
rozen
S
et
[
int
]:
return
self
.
_hashless_allocator
.
all_block_ids
return
self
.
_hashless_allocator
.
all_block_ids
def
promote_to_immutable_block
(
self
,
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
block
:
"PrefixCachingBlock"
)
->
BlockId
:
"""Once a mutable block is full, it can be promoted to an immutable
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
block. This means that its content can be referenced by future blocks
having the same prefix.
having the same prefix.
...
@@ -249,7 +302,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -249,7 +302,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block.
block.
Args:
Args:
block
(PrefixCachingBlock)
: The mutable block to be promoted.
block: The mutable block to be promoted.
Returns:
Returns:
BlockId: Either the original block index, or the block index of
BlockId: Either the original block index, or the block index of
...
@@ -266,7 +319,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -266,7 +319,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else
:
else
:
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
self
.
_free_block_id_for_block
(
block
.
block_id
,
block
)
self
.
_incr_refcount_cached_block
(
self
.
_incr_refcount_cached_block
(
block
.
content_hash
,
self
.
_cached_blocks
[
block
.
content_hash
])
block
,
self
.
_cached_blocks
[
block
.
content_hash
])
return
self
.
_cached_blocks
[
block
.
content_hash
]
return
self
.
_cached_blocks
[
block
.
content_hash
]
...
@@ -293,29 +346,63 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -293,29 +346,63 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"""
"""
return
self
.
_cow_tracker
.
clear_cows
()
return
self
.
_cow_tracker
.
clear_cows
()
def
mark_blocks_as_computed
(
self
)
->
None
:
def
mark_blocks_as_accessed
(
self
,
block_ids
:
List
[
int
],
now
:
float
)
->
None
:
"""Mark blocks as accessed, used in prefix caching.
If the block is added into evictor, we need to update corresponding
info in evictor's metadata.
"""
for
block_id
in
block_ids
:
if
block_id
in
self
.
_blocks
:
self
.
_blocks
[
block_id
].
last_accessed
=
now
elif
block_id
in
self
.
evictor
:
self
.
evictor
.
update
(
block_id
,
now
)
else
:
raise
ValueError
(
"Mark block as accessed which is not belonged to GPU"
)
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
"""Mark blocks as computed, used in prefix caching."""
"""Mark blocks as computed, used in prefix caching."""
# TODO Track computed blocks.
pass
for
block_id
in
block_ids
:
if
block_id
in
self
.
_blocks
:
# only those full block is valid for prefix caching
if
self
.
_blocks
[
block_id
].
is_full
:
self
.
_blocks
[
block_id
].
computed
=
True
elif
block_id
not
in
self
.
evictor
:
raise
ValueError
(
f
"Mark
{
block_id
=
}
as computed which "
"is not belonged to GPU"
)
def
block_is_computed
(
self
,
block_id
:
int
)
->
bool
:
if
block_id
in
self
.
_blocks
:
return
self
.
_blocks
[
block_id
].
computed
else
:
return
block_id
in
self
.
evictor
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Return the block ids that are common for a given sequence group.
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
Only those blocks that are immutable and already be marked
compyted would be taken consideration.
"""
"""
# TODO: Track computed blocks.
computed
=
lambda
block_id
:
False
# NOTE We exclude the last block to avoid the case where the entire
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in model
# prompt is cached. This would cause erroneous behavior in model
# runner.
# runner.
ids_list
=
[
ids_list
=
[
takewhile
(
lambda
block_id
:
computed
(
block_id
),
seq
[:
-
1
])
list
(
for
seq
in
seq_block_ids
takewhile
(
lambda
block_id
:
self
.
block_is_computed
(
block_id
),
seq
[:
-
1
]))
for
seq
in
seq_block_ids
]
]
return
commonprefix
([
ids
for
ids
in
ids_list
if
ids
!=
[]])
# It returns a list of int although type annotation says list of string.
return
commonprefix
([
ids
for
ids
in
ids_list
# type: ignore
if
ids
!=
[]
])
class
PrefixCachingBlock
(
Block
):
class
PrefixCachingBlock
(
Block
):
...
@@ -332,7 +419,7 @@ class PrefixCachingBlock(Block):
...
@@ -332,7 +419,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
block_size (int): The maximum number of token IDs that can be stored in
the block.
the block.
prefix_caching_allocator (
PrefixCaching
BlockAllocator): The prefix
prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block.
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
of this block. Defaults to None.
...
@@ -340,17 +427,25 @@ class PrefixCachingBlock(Block):
...
@@ -340,17 +427,25 @@ class PrefixCachingBlock(Block):
def
__init__
(
def
__init__
(
self
,
self
,
prev_block
:
Optional
[
"PrefixCaching
Block
"
],
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
block_size
:
int
,
block_size
:
int
,
prefix_caching_allocator
:
PrefixCaching
BlockAllocator
,
prefix_caching_allocator
:
BlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
block_id
:
Optional
[
int
]
=
None
,
computed
:
bool
=
False
,
):
):
assert
isinstance
(
prefix_caching_allocator
,
PrefixCachingBlockAllocator
),
(
"Currently this class is only tested with "
"PrefixCachingBlockAllocator."
)
assert_prefix_caching_block_or_none
(
prev_block
)
assert_prefix_caching_block_or_none
(
prev_block
)
self
.
_prev_block
=
prev_block
self
.
_prev_block
=
prev_block
self
.
_cached_content_hash
:
Optional
[
int
]
=
None
self
.
_cached_content_hash
:
Optional
[
int
]
=
None
self
.
_cached_num_tokens_total
:
Optional
[
int
]
=
None
self
.
_prefix_caching_allocator
=
prefix_caching_allocator
self
.
_prefix_caching_allocator
=
prefix_caching_allocator
self
.
_last_accessed
:
float
=
_DEFAULT_LAST_ACCESSED_TIME
self
.
_computed
=
computed
self
.
_block
=
NaiveBlock
(
self
.
_block
=
NaiveBlock
(
prev_block
=
prev_block
,
prev_block
=
prev_block
,
...
@@ -361,6 +456,22 @@ class PrefixCachingBlock(Block):
...
@@ -361,6 +456,22 @@ class PrefixCachingBlock(Block):
_cow_target
=
self
,
_cow_target
=
self
,
)
)
@
property
def
computed
(
self
)
->
bool
:
return
self
.
_computed
@
computed
.
setter
def
computed
(
self
,
value
)
->
None
:
self
.
_computed
=
value
@
property
def
last_accessed
(
self
)
->
float
:
return
self
.
_last_accessed
@
last_accessed
.
setter
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
self
.
_last_accessed
=
last_accessed_ts
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
"""Appends the given token IDs to the block and registers the block as
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
immutable if the block becomes full.
...
@@ -398,6 +509,27 @@ class PrefixCachingBlock(Block):
...
@@ -398,6 +509,27 @@ class PrefixCachingBlock(Block):
def
num_empty_slots
(
self
)
->
int
:
def
num_empty_slots
(
self
)
->
int
:
return
self
.
_block
.
num_empty_slots
return
self
.
_block
.
num_empty_slots
@
property
def
num_tokens_total
(
self
)
->
int
:
"""return the total tokens so far.
Here we iterate the block chain till to the first block, while
cache the result in local to prevent repeated computations.
"""
if
self
.
_cached_num_tokens_total
is
not
None
:
return
self
.
_cached_num_tokens_total
_block
:
Optional
[
Block
]
=
self
self
.
_cached_num_tokens_total
=
0
# TODO: current implement here take O(N^2), we expect future
# we have O(1) here
while
_block
is
not
None
:
self
.
_cached_num_tokens_total
+=
len
(
_block
.
token_ids
)
_block
=
_block
.
prev_block
return
self
.
_cached_num_tokens_total
@
property
@
property
def
block_size
(
self
)
->
int
:
def
block_size
(
self
)
->
int
:
return
self
.
_block
.
block_size
return
self
.
_block
.
block_size
...
@@ -428,8 +560,10 @@ class PrefixCachingBlock(Block):
...
@@ -428,8 +560,10 @@ class PrefixCachingBlock(Block):
return
None
return
None
is_first_block
=
self
.
_prev_block
is
None
is_first_block
=
self
.
_prev_block
is
None
prev_block_hash
=
(
None
if
is_first_block
else
prev_block_hash
=
(
self
.
_prev_block
.
content_hash
)
None
if
is_first_block
else
self
.
_prev_block
.
content_hash
# type: ignore
)
# Previous block exists but does not yet have a hash.
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
# Return no hash in this case.
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment