Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99b471c2
Commit
99b471c2
authored
May 21, 2024
by
zhuwenwen
Browse files
merge v0.4.1
parents
1925d2e9
468d761b
Changes
336
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3108 additions
and
649 deletions
+3108
-649
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+39
-2
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+62
-31
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+377
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+250
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+89
-136
vllm/attention/layer.py
vllm/attention/layer.py
+6
-3
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+10
-11
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+27
-17
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+812
-0
vllm/attention/selector.py
vllm/attention/selector.py
+50
-12
vllm/config.py
vllm/config.py
+424
-111
vllm/core/block/__init__.py
vllm/core/block/__init__.py
+0
-0
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+52
-6
vllm/core/block/common.py
vllm/core/block/common.py
+4
-2
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+18
-13
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+33
-15
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+62
-24
vllm/core/interfaces.py
vllm/core/interfaces.py
+13
-7
vllm/core/policy.py
vllm/core/policy.py
+1
-3
vllm/core/scheduler.py
vllm/core/scheduler.py
+779
-256
No files found.
vllm/attention/backends/abstract.py
View file @
99b471c2
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
import
torch
...
...
@@ -47,7 +47,8 @@ class AttentionBackend(ABC):
@
dataclass
class
AttentionMetadata
:
class
AttentionMetadataPerStage
:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
...
...
@@ -59,6 +60,41 @@ class AttentionMetadata:
}
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadataPerStage
)
@dataclass
class AttentionMetadata(Generic[T]):
    """Attention metadata for prefill and decode batched together.

    The type parameter ``T`` is the backend-specific per-stage metadata
    class; one optional instance is held for the prefill part of the batch
    and one for the decode part.
    """
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # The attention metadata for prefill requests in a batch.
    # None if there's no prefill requests in a batch.
    prefill_metadata: Optional[T]
    # The attention metadata for decode requests in a batch.
    # None if there's no decode requests in a batch.
    decode_metadata: Optional[T]
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # The kv cache's data type.
    kv_cache_dtype: str

    def __post_init__(self) -> None:
        """Check that the token counts agree with the per-stage metadata.

        A positive prefill token count requires at least one prefill request
        and a non-None ``prefill_metadata``; a positive decode token count
        requires a non-None ``decode_metadata``.

        Raises:
            AssertionError: If one of the conditions above does not hold.
        """
        if self.num_prefill_tokens > 0:
            assert self.num_prefills > 0
            assert self.prefill_metadata is not None
        if self.num_decode_tokens > 0:
            assert self.decode_metadata is not None
class
AttentionImpl
(
ABC
):
@
abstractmethod
...
...
@@ -81,5 +117,6 @@ class AttentionImpl(ABC):
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
kv_scale
:
float
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/attention/backends/flash_attn.py
View file @
99b471c2
...
...
@@ -11,7 +11,8 @@ import torch
from
flash_attn
import
flash_attn_varlen_func
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
...
...
@@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
@
dataclass
class
FlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
FlashAttentionMetadata
(
AttentionMetadataPerStage
,
PagedAttentionMetadata
):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
...
...
@@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
...
...
@@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class
FlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_pr
ompt
_tokens -------------->|
|<--pr
ompt
_0-->|<--pr
ompt
_1-->|...|<--pr
ompt
_N-1-->|
|<--------------- num_pr
efill
_tokens
---
-------------->|
|<--pr
efill
_0-->|<--pr
efill
_1-->|...|<--pr
efill
_N-1--
-
>|
Otherwise, the layout is as follows:
|<-----------------
-
num_
generation
_tokens
(M)
----------------->|
|<--
generation
_0-->|..........|<--
generation
_M-1-->|<--padding-->|
|<----------------- num_
decode
_tokens
-
----------------->|
|<--
decode
_0-->|..........|<--
decode
_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
...
...
@@ -155,7 +162,8 @@ class FlashAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
AttentionMetadata
[
FlashAttentionMetadata
],
kv_scale
:
float
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -184,55 +192,78 @@ class FlashAttentionImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
)
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
attn_metadata
.
is_prompt
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
kv_cache
is
None
or
attn_metada
ta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_me
ta
.
block_tables
.
numel
()
==
0
:
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out
put
=
flash_attn_varlen_func
(
out
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
attn_metada
ta
.
seq_start_loc
,
cu_seqlens_k
=
attn_metada
ta
.
seq_start_loc
,
max_seqlen_q
=
attn_metada
ta
.
max_prompt_len
,
max_seqlen_k
=
attn_metada
ta
.
max_prompt_len
,
cu_seqlens_q
=
prefill_me
ta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_me
ta
.
seq_start_loc
,
max_seqlen_q
=
prefill_me
ta
.
max_prompt_len
,
max_seqlen_k
=
prefill_me
ta
.
max_prompt_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
output
=
PagedAttention
.
forward_prefix
(
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
query
,
key
,
value
,
key_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
attn_metada
ta
.
subquery_start_loc
,
attn_metada
ta
.
prompt_lens_tensor
,
attn_metada
ta
.
context_lens
,
attn_metada
ta
.
max_subquery_len
,
prefill_me
ta
.
block_tables
,
prefill_me
ta
.
subquery_start_loc
,
prefill_me
ta
.
prompt_lens_tensor
,
prefill_me
ta
.
context_lens
,
prefill_me
ta
.
max_subquery_len
,
self
.
alibi_slopes
,
)
else
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
=
PagedAttention
.
forward_decode
(
query
,
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_
query
,
key_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
decode_me
ta
.
block_tables
,
decode_me
ta
.
context_lens
,
decode_me
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
)
# Reshape the output tensor.
...
...
vllm/attention/backends/rocm_flash_attn.py
0 → 100644
View file @
99b471c2
"""Attention layer ROCm GPUs."""
import
os
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend entry points for ROCm GPUs.

    Exposes the implementation and metadata classes defined in this module
    and forwards the KV-cache helpers to ``PagedAttention``.
    """

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
        """Build a ROCmFlashAttentionMetadata from the given arguments."""
        return ROCmFlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape given by PagedAttention.get_kv_cache_shape."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward the call unchanged to PagedAttention.swap_blocks."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward the call unchanged to PagedAttention.copy_blocks."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
                                 PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has different definition depending on if it is
    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
    # When it is for decoding, it includes a new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
class ROCmFlashAttentionImpl(AttentionImpl):
    """Attention implementation for ROCm GPUs.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the attention geometry and pick the prefill kernel.

        The kernel is chosen once, here: Triton flash attention when the
        ``VLLM_USE_TRITON_FLASH_ATTN`` environment variable is "true"/"1"
        (the default), otherwise the ``flash_attn`` package, otherwise the
        naive PyTorch implementation defined in this module.

        Raises:
            ValueError: If ``head_size`` is not among the head sizes
                reported by ``PagedAttention.get_supported_head_sizes()``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = (os.environ.get(
            "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
        if self.use_triton_flash_attn:
            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
        else:
            # if not using triton, navi3x not use flash-attn either
            if torch.cuda.get_device_capability()[0] == 11:
                self.use_naive_attn = True
            else:
                # Only the import is guarded: a missing package is the one
                # failure that should trigger the naive fallback.
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                except ModuleNotFoundError:
                    self.use_naive_attn = True
                else:
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")

            if self.use_naive_attn:
                self.attn_func = _naive_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :, None, :].expand(tokens, n_kv_heads, n_rep,
                                        head_dim).reshape(
                                            tokens, n_kv_heads * n_rep,
                                            head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: Passed through to the paged-cache write and to the
                decode kernel.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                attn_metadata.kv_cache_dtype,
                kv_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.prompt_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                if self.use_triton_flash_attn or self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.use_naive_attn:
                        out = self.attn_func(
                            query,
                            key,
                            value,
                            prefill_meta.prompt_lens,
                            self.scale,
                        )
                    else:
                        # NOTE(review): positional call into the Triton
                        # kernel wrapper; the slots presumably are (output
                        # buffer, cu_seqlens_q, cu_seqlens_k, max_seqlens_q,
                        # max_seqlens_k, causal, scale) - confirm against
                        # triton_attention.
                        out, _ = self.attn_func(
                            query,
                            key,
                            value,
                            None,
                            prefill_meta.seq_start_loc,
                            prefill_meta.seq_start_loc,
                            prefill_meta.max_prompt_len,
                            prefill_meta.max_prompt_len,
                            True,
                            self.scale,
                        )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prompt_len,
                        max_seqlen_k=prefill_meta.max_prompt_len,
                        softmax_scale=self.scale,
                        causal=True,
                    )
                # All three kernels produce the prefill rows; store them in
                # one place instead of repeating this in every branch.
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.prompt_lens_tensor,
                    prefill_meta.context_lens,
                    prefill_meta.max_subquery_len,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # key_cache / value_cache are only bound when kv_cache is not
            # None, so this branch relies on a cache being supplied.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
def _naive_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    prompt_lens: List[int],
    scale: float,
) -> torch.Tensor:
    """Run causal attention separately over each prompt of a packed batch.

    The prompts are laid out back to back along dimension 0 of ``query``,
    ``key`` and ``value``; ``prompt_lens`` gives the length of each one in
    order. Each slice is handed to ``_naive_masked_attention`` and the result
    is copied into the matching rows of the output.

    Args:
        query: Packed queries; dimension 0 indexes tokens.
        key: Packed keys, sliced with the same token ranges as ``query``.
        value: Packed values, sliced with the same token ranges as ``query``.
        prompt_lens: Number of tokens in each consecutive prompt.
        scale: Scaling factor forwarded to the per-prompt attention.

    Returns:
        A tensor shaped like ``query``. Rows beyond ``sum(prompt_lens)`` are
        left uninitialized because the buffer comes from ``empty_like``.
    """
    output = torch.empty_like(query)
    start = 0
    for prompt_len in prompt_lens:
        end = start + prompt_len
        out = _naive_masked_attention(
            query[start:end],
            key[start:end],
            value[start:end],
            scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output[start:end].copy_(out)
        start = end
    return output
def
_naive_masked_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
)
->
torch
.
Tensor
:
seq_len
,
head_size
,
head_dim
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
query
.
dtype
,
device
=
query
.
device
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
return
out
vllm/attention/backends/torch_sdpa.py
0 → 100644
View file @
99b471c2
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
class TorchSDPABackend(AttentionBackend):
    """Attention backend entry points for the torch SDPA implementation.

    Exposes the implementation and metadata classes defined in this module
    and forwards the KV-cache helpers to ``PagedAttention``.
    """

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the attention implementation class of this backend."""
        return TorchSDPABackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
        """Build a TorchSDPAMetadata from the given arguments."""
        return TorchSDPAMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape given by PagedAttention.get_kv_cache_shape."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward the call unchanged to PagedAttention.swap_blocks."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward the call unchanged to PagedAttention.copy_blocks."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
                        AttentionMetadataPerStage):
    """Metadata for TorchSDPABackend.

    Besides the declared fields, every instance carries an ``attn_bias``
    attribute that starts as None (see ``__post_init__``).
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    prompt_lens: Optional[List[int]]

    def __post_init__(self) -> None:
        """Initialise the ``attn_bias`` attribute to None."""
        # Set during the execution of the first attention op.
        # It is a list because a separate bias is kept per prompt.
        # Because it is assigned here instead of being declared as a
        # dataclass field, it will not appear in the generated __repr__
        # and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None
class TorchSDPABackendImpl(AttentionImpl):
    """Attention using torch scaled_dot_product_attention for prompts and
    PagedAttention for decoding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the attention geometry and masking configuration.

        Raises:
            ValueError: If ``head_size`` is not among the head sizes
                reported by ``PagedAttention.get_supported_head_sizes()``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            assert len(alibi_slopes) == num_heads
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # With ALiBi or a sliding window an explicit bias tensor is passed to
        # SDPA (which also encodes causality); otherwise is_causal is used.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: Passed through to the paged-cache write and to the
                decode kernel.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            RuntimeError: For a prompt run whose block tables are filled
                (prefix-enabled attention), which is not supported.
        """
        # Unpacking keeps the requirement that the query is 2-D.
        num_tokens, _ = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        if attn_metadata.is_prompt:
            assert attn_metadata.prompt_lens is not None
            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                # The biases are built once and cached on the metadata
                # object for subsequent calls with the same metadata.
                if attn_metadata.attn_bias is None:
                    if self.alibi_slopes is not None:
                        att_masks = _make_alibi_bias(
                            self.alibi_slopes, query.dtype,
                            attn_metadata.prompt_lens)  # type: ignore
                    elif self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.prompt_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = [None] * len(attn_metadata.prompt_lens)
                    attn_metadata.attn_bias = att_masks

                # (tokens, heads, head_size) -> (heads, tokens, head_size)
                query = query.movedim(0, query.dim() - 2)
                key = key.movedim(0, key.dim() - 2)
                value = value.movedim(0, value.dim() - 2)

                start = 0
                # Allocate next to the query so the result never ends up on
                # a different device than the inputs.
                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype,
                    device=query.device)
                for prompt_len, mask in zip(attn_metadata.prompt_lens,
                                            attn_metadata.attn_bias):
                    end = start + prompt_len
                    sub_out = scaled_dot_product_attention(
                        query[:, start:end, :],
                        key[:, start:end, :],
                        value[:, start:end, :],
                        attn_mask=mask,
                        dropout_p=0.0,
                        is_causal=not self.need_mask,
                        scale=self.scale).movedim(query.dim() - 2, 0)
                    output[start:end, :, :] = sub_out
                    start = end
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")

        else:
            # Decoding run.
            # key_cache / value_cache are only bound when kv_cache is not
            # None, so this branch relies on a cache being supplied.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
prompt_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
prompt_len
in
prompt_lens
:
bias
=
torch
.
arange
(
prompt_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
prompt_len
,
prompt_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
prompt_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
prompt_len
in
prompt_lens
:
tensor
=
torch
.
full
(
(
1
,
prompt_len
,
prompt_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
vllm/attention/backends/xformers.py
View file @
99b471c2
"""Attention layer with xFormers and PagedAttention."""
import
importlib
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
...
...
@@ -10,11 +9,11 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_hip
logger
=
init_logger
(
__name__
)
...
...
@@ -56,7 +55,7 @@ class XFormersBackend(AttentionBackend):
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
class
XFormersMetadata
(
AttentionMetadata
PerStage
,
PagedAttentionMetadata
):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
...
...
@@ -67,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens
:
Optional
[
List
[
int
]]
# prompt_lens stored as a tensor.
prompt_lens_tensor
:
Optional
[
torch
.
Tensor
]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens
:
int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens
:
int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
...
...
@@ -125,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
class
XFormersImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_pr
ompt
_tokens --------------->|
|<--pr
ompt
_0-->|<--pr
ompt
_1-->|...|<--pr
ompt
_N-1--->|
|<--------------- num_pr
efill
_tokens
--
--------------->|
|<--pr
efill
_0-->|<--pr
efill
_1-->|...|<--pr
efill
_N-1--->|
Otherwise, the layout is as follows:
|<-----------------
-
num_
generation
_tokens
(M)
----------------->|
|<--
generation
_0-->|..........|<--
generation
_M-1-->|<--padding-->|
|<----------------- num_
decode
_tokens
-
----------------->|
|<--
decode
_0-->|..........|<--
decode
_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
...
...
@@ -166,18 +165,14 @@ class XFormersImpl(AttentionImpl):
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
# AMD Radeon 7900 series (gfx1100) currently does not support xFormers
# nor FlashAttention. As a temporary workaround, we use naive PyTorch
# implementation of attention.
self
.
use_naive_attention
=
_check_use_naive_attention
()
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
XFormersMetadata
,
attn_metadata
:
AttentionMetadata
[
XFormersMetadata
],
kv_scale
:
float
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
@@ -205,86 +200,69 @@ class XFormersImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
)
attn_metadata
.
kv_cache_dtype
,
kv_scale
)
if
attn_metadata
.
is_prompt
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
kv_cache
is
None
or
attn_metada
ta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_me
ta
.
block_tables
.
numel
()
==
0
:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
if
self
.
use_naive_attention
:
output
=
torch
.
empty_like
(
query
)
start
=
0
for
_
,
prompt_len
in
enumerate
(
attn_metadata
.
prompt_lens
):
end
=
start
+
prompt_len
out
=
_naive_masked_attention
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
self
.
num_heads
,
self
.
num_kv_heads
,
self
.
head_size
,
self
.
scale
,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
)
start
+=
prompt_len
# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return
output
.
reshape
(
num_tokens
,
hidden_size
)
output
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
attn_metadata
)
out
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
prefill_meta
)
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
output
=
PagedAttention
.
forward_prefix
(
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
out
=
PagedAttention
.
forward_prefix
(
query
,
key
,
value
,
key_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
attn_metada
ta
.
subquery_start_loc
,
attn_metada
ta
.
prompt_lens_tensor
,
attn_metada
ta
.
context_lens
,
attn_metada
ta
.
max_subquery_len
,
prefill_me
ta
.
block_tables
,
prefill_me
ta
.
subquery_start_loc
,
prefill_me
ta
.
prompt_lens_tensor
,
prefill_me
ta
.
context_lens
,
prefill_me
ta
.
max_subquery_len
,
self
.
alibi_slopes
,
)
else
:
# Decoding run.
output
=
PagedAttention
.
forward_decode
(
query
,
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
attn_metada
ta
.
block_tables
,
attn_metada
ta
.
context_lens
,
attn_metada
ta
.
max_context_len
,
decode_me
ta
.
block_tables
,
decode_me
ta
.
context_lens
,
decode_me
ta
.
max_context_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
kv_scale
,
)
# Reshape the output tensor.
...
...
@@ -300,13 +278,31 @@ class XFormersImpl(AttentionImpl):
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
output: shape = [num_pr
ompt
_tokens, num_heads, head_size]
query: shape = [num_pr
ompt
_tokens, num_heads, head_size]
key: shape = [num_pr
ompt
_tokens, num_kv_heads, head_size]
value: shape = [num_pr
ompt
_tokens, num_kv_heads, head_size]
output: shape = [num_pr
efill
_tokens, num_heads, head_size]
query: shape = [num_pr
efill
_tokens, num_heads, head_size]
key: shape = [num_pr
efill
_tokens, num_kv_heads, head_size]
value: shape = [num_pr
efill
_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert
attn_metadata
.
prompt_lens
is
not
None
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which is different
# from a spec from the doc).
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
...
...
@@ -323,12 +319,11 @@ class XFormersImpl(AttentionImpl):
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
prompt_lens
)
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
# Add the batch dimension.
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
...
...
@@ -338,16 +333,14 @@ class XFormersImpl(AttentionImpl):
value
,
attn_bias
=
attn_metadata
.
attn_bias
[
0
],
p
=
0.0
,
scale
=
self
.
scale
,
op
=
op
)
return
out
.
view_as
(
query
)
scale
=
self
.
scale
)
return
out
.
view_as
(
original_query
)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
original_
query
)
start
=
0
for
i
,
prompt_len
in
enumerate
(
attn_metadata
.
prompt_lens
):
end
=
start
+
prompt_len
...
...
@@ -357,10 +350,9 @@ class XFormersImpl(AttentionImpl):
value
[
None
,
start
:
end
],
attn_bias
=
attn_metadata
.
attn_bias
[
i
],
p
=
0.0
,
scale
=
self
.
scale
,
op
=
op
)
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
squeeze
(
0
))
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]
))
start
+=
prompt_len
return
output
...
...
@@ -399,42 +391,3 @@ def _make_alibi_bias(
attn_biases
.
append
(
LowerTriangularMaskWithTensorBias
(
bias
))
return
attn_biases
def _check_use_naive_attention() -> bool:
    """Decide whether the naive PyTorch attention fallback must be used.

    Returns:
        False when ``is_hip()`` is false. Otherwise True exactly when the
        ``flash_attn`` package cannot be located; a warning is logged in
        that case.
    """
    # A bare `import importlib` does not guarantee that the `util` submodule
    # has been loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    if not is_hip():
        return False

    # For ROCm, check whether flash attention is installed or not.
    use_naive_attention = importlib.util.find_spec("flash_attn") is None
    if use_naive_attention:
        logger.warning("flash_attn is not installed. Using naive attention. "
                       "This will take significantly more GPU memory.")
    return use_naive_attention
def _naive_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    scale: float,
) -> torch.Tensor:
    """Causally-masked attention over one sequence using plain PyTorch ops.

    The inputs are viewed as ``[tokens, heads, head_size]``. Scores are
    computed in float32, positions above the diagonal are pushed to the most
    negative finite value of the query dtype, and the softmax weights are
    cast back to the value dtype before being applied.

    Returns:
        Tensor shaped ``[tokens, num_heads, head_size]``.
    """
    q = query.view(-1, num_heads, head_size)
    k = key.view(-1, num_kv_heads, head_size)
    v = value.view(-1, num_kv_heads, head_size)
    num_tokens, _, _ = q.shape

    # Ones strictly above the diagonal mark the "future" key positions.
    future_positions = torch.triu(
        torch.ones(num_tokens, num_tokens, dtype=q.dtype, device=q.device),
        diagonal=1,
    )
    causal_bias = future_positions * torch.finfo(q.dtype).min

    scores = scale * torch.einsum("qhd,khd->hqk", q, k).float()
    scores = scores + causal_bias.float()
    probs = torch.softmax(scores, dim=-1).to(v.dtype)
    return torch.einsum("hqk,khd->qhd", probs, v)
vllm/attention/layer.py
View file @
99b471c2
...
...
@@ -4,7 +4,8 @@ from typing import List, Optional
import
torch
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
(
AttentionMetadata
,
AttentionMetadataPerStage
)
from
vllm.attention.selector
import
get_attn_backend
...
...
@@ -41,6 +42,8 @@ class Attention(nn.Module):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
[
AttentionMetadataPerStage
],
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
kv_scale
)
vllm/attention/ops/paged_attn.py
View file @
99b471c2
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import
torch
from
vllm
._C
import
cache
_ops
,
ops
from
vllm
import
_custom
_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
...
...
@@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
@
dataclass
class
PagedAttentionMetadata
:
"""Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
...
...
@@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
kv_cache_dtype
:
str
class
PagedAttention
:
...
...
@@ -73,14 +67,16 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
)
->
None
:
cache_
ops
.
reshape_and_cache
(
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
kv_scale
,
)
@
staticmethod
...
...
@@ -95,6 +91,7 @@ class PagedAttention:
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
...
...
@@ -126,6 +123,7 @@ class PagedAttention:
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
)
else
:
# Run PagedAttention V2.
...
...
@@ -157,6 +155,7 @@ class PagedAttention:
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
)
return
output
...
...
@@ -200,11 +199,11 @@ class PagedAttention:
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
...
...
@@ -213,4 +212,4 @@ class PagedAttention:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
vllm/attention/ops/prefix_prefill.py
View file @
99b471c2
...
...
@@ -47,7 +47,8 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
...
...
@@ -59,26 +60,30 @@ if triton.__version__ >= "2.1.0":
cur_batch_ctx_len
=
tl
.
load
(
B_Ctxlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_query_len
=
cur_batch_seq_len
-
cur_batch_ctx_len
block_start_loc
=
BLOCK_M
*
start_m
# initialize offsets
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
_PADDED
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
,
other
=
0.0
)
dim_mask
=
tl
.
where
(
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_query_len
),
other
=
0.0
)
# # initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
_PADDED
],
dtype
=
tl
.
float32
)
for
start_n
in
range
(
0
,
cur_batch_ctx_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
...
...
@@ -99,7 +104,8 @@ if triton.__version__ >= "2.1.0":
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
...
...
@@ -126,7 +132,8 @@ if triton.__version__ >= "2.1.0":
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
v
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
(
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
,
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
...
...
@@ -142,16 +149,15 @@ if triton.__version__ >= "2.1.0":
k_ptrs
=
K
+
off_k
v_ptrs
=
V
+
off_v
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_seq_len
-
cur_batch_ctx_len
,
1
,
0
)
block_mask
=
tl
.
where
(
block_start_loc
<
cur_batch_query_len
,
1
,
0
)
for
start_n
in
range
(
0
,
block_mask
*
(
start_m
+
1
)
*
BLOCK_M
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
# -- compute qk ----
k
=
tl
.
load
(
k_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_seq_len
-
cur_batch_
ctx
_len
,
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_
query
_len
)
,
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
...
...
@@ -179,8 +185,8 @@ if triton.__version__ >= "2.1.0":
# update acc
v
=
tl
.
load
(
v_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
mask
=
(
start_n
+
offs_n
[:,
None
])
<
cur_batch_seq_len
-
cur_batch_
ctx
_len
,
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_
query
_len
)
,
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
...
...
@@ -195,7 +201,8 @@ if triton.__version__ >= "2.1.0":
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
)
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_query_len
))
return
@
triton
.
jit
...
...
@@ -636,7 +643,8 @@ if triton.__version__ >= "2.1.0":
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lk
in
{
16
,
32
,
64
,
128
}
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded
=
2
**
((
Lk
-
1
).
bit_length
())
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
...
...
@@ -646,6 +654,7 @@ if triton.__version__ >= "2.1.0":
num_warps
=
8
if
Lk
<=
64
else
8
if
alibi_slopes
is
not
None
:
assert
Lk
==
Lk_padded
_fwd_kernel_alibi
[
grid
](
q
,
k
,
...
...
@@ -738,6 +747,7 @@ if triton.__version__ >= "2.1.0":
num_queries_per_kv
=
num_queries_per_kv
,
BLOCK_M
=
BLOCK
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
...
...
vllm/attention/ops/triton_flash_attention.py
0 → 100644
View file @
99b471c2
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import
torch
import
triton
import
triton.language
as
tl
torch_dtype
:
tl
.
constexpr
=
torch
.
float16
@triton.jit
def cdiv_fn(x, y):
    """Integer division of ``x`` by ``y`` rounded up, as (x + y - 1) // y."""
    rounded_up = x + y - 1
    return rounded_up // y
@triton.jit
def max_fn(x, y):
    """Forward ``x`` and ``y`` to ``tl.math.max`` and return its result."""
    larger = tl.math.max(x, y)
    return larger
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Build an (m, n) grid of offsets: philox_offset + row * stride + col.

    ``philox_seed`` and ``dropout_p`` are accepted but not read here.
    """
    row_ids = tl.arange(0, m)
    col_ids = tl.arange(0, n)
    row_base = row_ids[:, None] * stride
    return philox_offset + row_base + col_ids[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return ``tl.rand`` samples for the (m, n) offset grid of this tile."""
    offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
    offsets_u32 = offsets.to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, offsets_u32)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return the keep-mask: True where the random sample exceeds dropout_p."""
    samples = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    return samples > dropout_p
@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load a tile through ``block_ptr`` with optional boundary checks.

    ``first`` requests a boundary check on dimension 0 and ``second`` on
    dimension 1; ``pad`` is forwarded as the ``padding_option`` whenever a
    check is requested. With neither flag set the tile is loaded without
    any boundary check (and ``pad`` is unused).
    """
    if first and second:
        # Check both dimensions.
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        # Check dimension 0 only.
        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
    elif second:
        # Check dimension 1 only.
        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        # Caller asserts the whole tile is in range.
        tensor = tl.load(block_ptr)
    return tensor
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
):
    """Accumulate attention over the K/V blocks in [block_min, block_max).

    For each step of ``BLOCK_N`` the loop loads a K tile (and a V tile),
    forms ``qk``, applies the optional padding mask, causal mask and bias,
    then folds the tile into the running maximum ``m_i``, running sum
    ``l_i`` and output accumulator ``acc`` using base-2 exponentials.
    The K, V, bias and encoded-softmax block pointers are advanced by
    ``BLOCK_N`` at the end of every iteration.

    ``masked_blocks`` and ``BLOCK_DMODEL`` are accepted but not referenced
    in this body.

    Returns:
        The updated ``(acc, l_i, m_i)`` triple.
    """
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            # Note the flag order is swapped relative to the K load above.
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M],
                                     actual_seqlen_k,
                                     dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                # Columns at or beyond actual_seqlen_k are set to -inf.
                qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        # Added onto the pre-masked tile, so -inf entries stay -inf.
        qk += tl.dot(q, k)
        if bias_ptr is not None:
            bias = load_fn(bias_ptr, False, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        # New running row maximum, then exponentiate relative to it.
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                # Dropped entries are stored negated.
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- update output accumulator --
        # alpha rescales earlier contributions to the new running maximum.
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij
        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
        # Step every block pointer to the next BLOCK_N columns / rows.
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i
@
triton
.
autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
2
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"waves_per_eu"
:
2
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"waves_per_eu"
:
2
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
3
,
"PRE_LOAD_V"
:
True
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
3
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
4
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
32
,
"BLOCK_N"
:
32
,
"waves_per_eu"
:
4
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton
.
Config
(
{
"BLOCK_M"
:
16
,
"BLOCK_N"
:
16
,
"waves_per_eu"
:
1
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
],
key
=
[
"hq"
,
"hk"
,
"IS_CAUSAL"
,
"dropout_p"
,
"BLOCK_DMODEL"
],
)
@
triton
.
jit
def
attn_fwd
(
Q
,
K
,
V
,
bias
,
sm_scale
,
L
,
Out
,
stride_qz
,
stride_qh
,
stride_qm
,
stride_qk
,
stride_kz
,
stride_kh
,
stride_kn
,
stride_kk
,
stride_vz
,
stride_vh
,
stride_vk
,
stride_vn
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
stride_bz
,
stride_bh
,
stride_bm
,
stride_bn
,
cu_seqlens_q
,
cu_seqlens_k
,
dropout_p
,
philox_seed
,
philox_offset_base
,
encoded_softmax
,
hq
,
hk
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
VARLEN
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
BIAS_TYPE
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
):
start_m
=
tl
.
program_id
(
0
)
off_h_q
=
tl
.
program_id
(
1
)
off_z
=
tl
.
program_id
(
2
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
if
VARLEN
:
cu_seqlens_q_start
=
tl
.
load
(
cu_seqlens_q
+
off_z
)
cu_seqlens_q_end
=
tl
.
load
(
cu_seqlens_q
+
off_z
+
1
)
seqlen_q
=
cu_seqlens_q_end
-
cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if
start_m
*
BLOCK_M
>
seqlen_q
:
return
cu_seqlens_k_start
=
tl
.
load
(
cu_seqlens_k
+
off_z
)
cu_seqlens_k_end
=
tl
.
load
(
cu_seqlens_k
+
off_z
+
1
)
seqlen_k
=
cu_seqlens_k_end
-
cu_seqlens_k_start
else
:
cu_seqlens_q_start
=
0
cu_seqlens_k_start
=
0
seqlen_q
=
MAX_SEQLENS_Q
seqlen_k
=
MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks
=
cdiv_fn
(
seqlen_k
,
BLOCK_N
)
if
IS_CAUSAL
:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen
=
cdiv_fn
(
(
start_m
+
1
)
*
BLOCK_M
+
seqlen_k
-
seqlen_q
,
BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks
=
min
(
n_blocks
,
n_blocks_seqlen
)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if
n_blocks
<=
0
:
o_offset
=
(
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
)
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
Out
.
type
.
element_ty
)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
is_mqa
=
hq
!=
hk
if
is_mqa
:
# noqa: SIM108
off_h_k
=
off_h_q
%
hk
else
:
off_h_k
=
off_h_q
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
n_extra_tokens
=
BLOCK_N
-
seqlen_k
elif
seqlen_k
%
BLOCK_N
:
n_extra_tokens
=
seqlen_k
%
BLOCK_N
padded_head
=
ACTUAL_BLOCK_DMODEL
!=
BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset
=
(
off_z
*
stride_qz
+
off_h_q
*
stride_qh
+
cu_seqlens_q_start
*
stride_qm
)
Q_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
+
q_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_qm
,
stride_qk
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
k_offset
=
(
off_z
*
stride_kz
+
off_h_k
*
stride_kh
+
cu_seqlens_k_start
*
stride_kn
)
K_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
+
k_offset
,
shape
=
(
ACTUAL_BLOCK_DMODEL
,
seqlen_k
),
strides
=
(
stride_kk
,
stride_kn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
order
=
(
0
,
1
),
)
v_offset
=
(
off_z
*
stride_vz
+
off_h_k
*
stride_vh
+
cu_seqlens_k_start
*
stride_vk
)
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
+
v_offset
,
shape
=
(
seqlen_k
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_vk
,
stride_vn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_N
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
if
BIAS_TYPE
!=
0
:
bias_ptr
=
tl
.
make_block_ptr
(
base
=
bias
+
off_h_q
*
stride_bh
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
stride_bm
,
stride_bn
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
),
)
else
:
bias_ptr
=
None
if
ENABLE_DROPOUT
:
batch_philox_offset
=
philox_offset_base
\
+
(
off_z
*
hq
+
off_h_q
)
\
*
seqlen_q
*
seqlen_k
else
:
batch_philox_offset
=
0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer to indicate the mask is not
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
make_block_ptr
(
base
=
encoded_softmax
+
off_h_q
*
seqlen_q
*
seqlen_k
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
seqlen_k
,
1
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
),
)
else
:
encoded_softmax_block_ptr
=
0
# initialize pointer to m and l
m_i
=
tl
.
full
([
BLOCK_M
],
float
(
"-inf"
),
dtype
=
tl
.
float32
)
l_i
=
tl
.
full
([
BLOCK_M
],
1.0
,
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale
=
sm_scale
*
1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q
=
load_fn
(
Q_block_ptr
,
True
,
padded_head
,
"zero"
)
q
=
(
q
*
qk_scale
).
to
(
Q_block_ptr
.
type
.
element_ty
)
# Here we compute how many full and masked blocks we have.
padded_block_k
=
n_extra_tokens
!=
0
is_modulo_mn
=
not
padded_block_k
and
(
seqlen_q
%
BLOCK_M
==
0
)
if
IS_CAUSAL
:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks
=
BLOCK_M
//
BLOCK_N
+
(
not
is_modulo_mn
)
else
:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks
=
padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks
=
min
(
masked_blocks
,
n_blocks
)
n_full_blocks
=
n_blocks
-
masked_blocks
block_min
=
0
block_max
=
n_blocks
*
BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if
n_full_blocks
>
0
:
block_max
=
(
n_blocks
-
masked_blocks
)
*
BLOCK_N
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min
,
block_max
,
0
,
0
,
0
,
bias_ptr
,
# IS_CAUSAL, ....
False
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
False
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
padded_head
,
)
block_min
=
block_max
block_max
=
n_blocks
*
BLOCK_N
tl
.
debug_barrier
()
# Remaining blocks, if any, are the masked ones (causal and/or padded).
if
masked_blocks
>
0
:
offs_n_causal
=
offs_n
+
(
seqlen_q
-
seqlen_k
)
if
IS_CAUSAL
else
0
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
n_full_blocks
*
BLOCK_N
,
0
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
n_full_blocks
))
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
IS_CAUSAL
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
True
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
padded_head
,
)
# epilogue
acc
=
acc
/
l_i
[:,
None
]
if
ENABLE_DROPOUT
:
acc
=
acc
/
(
1
-
dropout_p
)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx
=
(
start_m
+
1
)
*
BLOCK_M
start_m_idx
=
start_m
*
BLOCK_M
causal_start_idx
=
seqlen_q
-
seqlen_k
acc
=
acc
.
to
(
Out
.
type
.
element_ty
)
if
IS_CAUSAL
:
# noqa: SIM102
if
causal_start_idx
>
start_m_idx
and
causal_start_idx
<
end_m_idx
:
out_mask_boundary
=
tl
.
full
((
BLOCK_DMODEL
,
),
causal_start_idx
,
dtype
=
tl
.
int32
)
mask_m_offsets
=
start_m_idx
+
tl
.
arange
(
0
,
BLOCK_M
)
out_ptrs_mask
=
(
mask_m_offsets
[:,
None
]
>=
out_mask_boundary
[
None
,
:])
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset
=
(
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
)
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl
.
store
(
O_block_ptr
,
acc
,
boundary_check
=
(
0
,
1
))
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Validate the tensors handed to the Triton attention kernel.

    Args:
        q, k, v: Query, key and value tensors. They must share the same
            number of dimensions, dtype and last-dimension size; ``k`` and
            ``v`` must have identical shapes.
        o: Output tensor; must have the same shape as ``q``.
        varlen: If True the inputs are expected to be 3-D and
            ``cu_seqlens_q`` / ``cu_seqlens_k`` are required. If False the
            inputs are expected to be 4-D and ``max_seqlens`` is required.
        max_seqlens: Positive maximum sequence length (non-varlen only).
        cu_seqlens_q, cu_seqlens_k: Cumulative sequence-length containers of
            equal length (varlen only).

    Raises:
        AssertionError: If any of the constraints above is violated.
    """
    assert q.dim() == k.dim() and q.dim() == v.dim(), (
        "q, k and v must have the same number of dimensions")
    if varlen:
        assert q.dim() == 3, "varlen inputs must be 3-D"
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None, "cu_seqlens_q is required"
        assert cu_seqlens_k is not None, "cu_seqlens_k is required"
        assert len(cu_seqlens_q) == len(cu_seqlens_k), (
            "cu_seqlens_q and cu_seqlens_k must have the same length")
    else:
        assert q.dim() == 4, "non-varlen inputs must be 4-D"
        _, nheads_q, _, head_size = q.shape
        _, nheads_k, _, _ = k.shape
        # Guard against the default None explicitly: comparing None with an
        # int would raise TypeError rather than a validation failure.
        assert max_seqlens is not None and max_seqlens > 0, (
            "max_seqlens must be a positive integer")
    assert k.shape == v.shape, "k and v must have the same shape"
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1], (
        "q, k and v must have the same head size")
    # TODO: Change assert if we support qk f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype, (
        "q, k and v must have the same dtype")
    assert head_size <= 256, "head sizes above 256 are not supported"
    assert o.shape == q.shape, "o must have the same shape as q"
    # Query heads are grouped evenly over the key/value heads.
    assert nheads_k > 0 and (nheads_q % nheads_k) == 0, (
        "the number of query heads must be a multiple of the number of "
        "key/value heads")
class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the ``attn_fwd`` Triton kernel.

    Only ``forward`` is defined in this block. The wrapper always runs the
    kernel in variable-length mode with dropout and encoded-softmax output
    disabled.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
    ):
        """Run the forward attention kernel on packed (varlen) inputs.

        Args:
            ctx: Autograd context; launch parameters are stashed on it.
            q: Query tensor of shape ``(total_q, nheads_q, head_size)``.
            k: Key tensor of shape ``(total_k, nheads_k, head_size)``.
            v: Value tensor with the same shape as ``k``.
            o: Output tensor shaped like ``q``; allocated with ``v``'s dtype
                when None.
            cu_seqlens_q: Cumulative query sequence lengths; the batch size
                is taken as ``len(cu_seqlens_q) - 1``.
            cu_seqlens_k: Cumulative key sequence lengths.
            max_seqlens_q: Maximum query length; determines the launch grid.
            max_seqlens_k: Maximum key length, forwarded to the kernel.
            causal: Forwarded to the kernel as ``IS_CAUSAL``.
            sm_scale: Softmax scale forwarded to the kernel.
            bias: Optional bias tensor; its first four strides are read, so
                it is presumably 4-D (TODO confirm against the kernel).

        Returns:
            Tuple ``(o, encoded_softmax)``; ``encoded_softmax`` is always
            None here.
        """
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )

        # Only the varlen layout is supported. The original carried an
        # unreachable ``else`` branch for a batched layout; it was dead code
        # and has been dropped.
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        batch = len(cu_seqlens_q) - 1
        # The leading (batch) stride is 0: sequences are packed along dim 0
        # and located through the cu_seqlens arguments instead.
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))

        # Smallest power of two that is >= head_size, but never below 32.
        # Computed arithmetically rather than by scanning a set, whose
        # iteration order is unspecified and could select a needlessly large
        # padded size. check_args has already enforced head_size <= 256.
        padded_d_model = max(32, 1 << (head_size - 1).bit_length())

        grid = lambda META: (
            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
            nheads_q,
            batch,
        )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            hq=nheads_q,
            hk=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
        )

        # Record the launch configuration on the autograd context.
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax
# Functional entry point: calls _attention.forward through autograd's apply.
triton_attention
=
_attention
.
apply
vllm/attention/selector.py
View file @
99b471c2
import
enum
import
os
from
functools
import
lru_cache
from
typing
import
Type
...
...
@@ -5,45 +7,81 @@ import torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
logger
=
init_logger
(
__name__
)
VLLM_ATTENTION_BACKEND
=
"VLLM_ATTENTION_BACKEND"
class _Backend(enum.Enum):
    """Attention backend implementations the selector can choose between."""

    # Values are auto-assigned; members are identified by name.
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
dtype
:
torch
.
dtype
)
->
Type
[
AttentionBackend
]:
if
_can_use_flash_attn
(
dtype
):
backend
=
_which_attn_to_use
(
dtype
)
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using FlashAttention backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
return
FlashAttentionBackend
el
se
:
el
if
backend
==
_Backend
.
XFORMERS
:
logger
.
info
(
"Using XFormers backend."
)
from
vllm.attention.backends.xformers
import
(
# noqa: F401
XFormersBackend
)
return
XFormersBackend
elif
backend
==
_Backend
.
ROCM_FLASH
:
logger
.
info
(
"Using ROCmFlashAttention backend."
)
from
vllm.attention.backends.rocm_flash_attn
import
(
# noqa: F401
ROCmFlashAttentionBackend
)
return
ROCmFlashAttentionBackend
elif
backend
==
_Backend
.
TORCH_SDPA
:
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
else
:
raise
ValueError
(
"Invalid attention backend."
)
def
_can_use_flash_attn
(
dtype
:
torch
.
dtype
)
->
bool
:
def
_which_attn_to_use
(
dtype
:
torch
.
dtype
)
->
_Backend
:
"""Returns which flash attention backend to use."""
if
is_cpu
():
return
_Backend
.
TORCH_SDPA
if
is_hip
():
# AMD GPUs.
logger
.
info
(
"Cannot use FlashAttention backend for AMD GPUs."
)
return
False
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
logger
.
info
(
"flash_atten is not supported on NAVI GPUs."
)
return
_Backend
.
ROCM_FLASH
# NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention backend for Volta and Turing "
"GPUs."
)
return
False
return
_Backend
.
XFORMERS
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info
(
"Cannot use FlashAttention backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
return
False
return
_Backend
.
XFORMERS
try
:
import
flash_attn
# noqa: F401
except
ImportError
:
logger
.
info
(
"Cannot use FlashAttention because the package is not found. "
"Please install it for better performance."
)
return
False
return
True
"Cannot use FlashAttention backend because the flash_attn package "
"is not found. Please install it for better performance."
)
return
_Backend
.
XFORMERS
backend_by_env_var
=
os
.
getenv
(
VLLM_ATTENTION_BACKEND
)
if
backend_by_env_var
is
not
None
:
return
_Backend
[
backend_by_env_var
]
# Default case.
return
_Backend
.
FLASH_ATTN
vllm/config.py
View file @
99b471c2
import
enum
import
json
import
os
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Optional
,
Union
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Union
import
torch
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
get_cpu_memory
,
get_nvcc_cuda_version
,
is_hip
,
is_neuron
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
is_neuron
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.model_executor.model_loader.loader
import
BaseModelLoader
logger
=
init_logger
(
__name__
)
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE
=
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
_GB
=
1
<<
30
...
...
@@ -30,18 +38,6 @@ class ModelConfig:
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
...
...
@@ -59,12 +55,19 @@ class ModelConfig:
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
"""
def
__init__
(
...
...
@@ -73,8 +76,6 @@ class ModelConfig:
tokenizer
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
download_dir
:
Optional
[
str
],
load_format
:
str
,
dtype
:
Union
[
str
,
torch
.
dtype
],
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
...
...
@@ -82,40 +83,26 @@ class ModelConfig:
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
5
,
skip_tokenizer_init
:
bool
=
False
,
)
->
None
:
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
self
.
trust_remote_code
=
trust_remote_code
self
.
download_dir
=
download_dir
self
.
load_format
=
load_format
self
.
seed
=
seed
self
.
revision
=
revision
self
.
code_revision
=
code_revision
self
.
tokenizer_revision
=
tokenizer_revision
self
.
quantization
=
quantization
self
.
quantization_param_path
=
quantization_param_path
self
.
enforce_eager
=
enforce_eager
self
.
max_context_len_to_capture
=
max_context_len_to_capture
self
.
max_logprobs
=
max_logprobs
if
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
if
not
os
.
path
.
exists
(
model
):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
download_dir
,
revision
=
revision
)
else
:
model_path
=
model
self
.
model
=
model_path
self
.
download_dir
=
model_path
self
.
tokenizer
=
model_path
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
)
...
...
@@ -123,39 +110,11 @@ class ModelConfig:
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
max_model_len
)
self
.
_verify_load_format
()
self
.
_verify_tokenizer_mode
()
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
def
_verify_load_format
(
self
)
->
None
:
load_format
=
self
.
load_format
.
lower
()
supported_load_format
=
[
"auto"
,
"pt"
,
"safetensors"
,
"npcache"
,
"dummy"
]
rocm_not_supported_load_format
=
[]
if
load_format
not
in
supported_load_format
:
raise
ValueError
(
f
"Unknown load format:
{
self
.
load_format
}
. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'."
)
if
is_hip
()
and
load_format
in
rocm_not_supported_load_format
:
rocm_supported_load_format
=
[
f
for
f
in
supported_load_format
if
(
f
not
in
rocm_not_supported_load_format
)
]
raise
ValueError
(
f
"load format '
{
load_format
}
' is not supported in ROCm. "
f
"Supported load format are "
f
"
{
rocm_supported_load_format
}
"
)
# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
if
"MixtralForCausalLM"
in
architectures
and
load_format
==
"pt"
:
raise
ValueError
(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. "
)
self
.
load_format
=
load_format
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
]:
...
...
@@ -165,32 +124,34 @@ class ModelConfig:
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
,
"marlin"
]
rocm_
not_
supported_quantization
=
[
"
aw
q"
,
"
marlin
"
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"
gpt
q"
,
"
squeezellm
"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
# Parse quantization method from the HF model config, if available.
hf_quant_config
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
is
not
None
:
hf_quant_method
=
str
(
hf_quant_config
[
"quant_method"
]).
lower
()
# If the GPTQ model is serialized in marlin format, use marlin.
if
(
hf_quant_method
==
"gptq"
and
"is_marlin_format"
in
hf_quant_config
and
hf_quant_config
[
"is_marlin_format"
]):
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
if
quant_cfg
is
not
None
:
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_format_marlin
=
(
quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin"
or
quant_cfg
.
get
(
"is_marlin_format"
,
False
))
# Use marlin if the GPTQ model is serialized in marlin format.
if
quant_method
==
"gptq"
and
is_format_marlin
:
logger
.
info
(
"The model is serialized in Marlin format. "
"Using Marlin kernel."
)
hf_
quant_method
=
"marlin"
quant_method
=
"marlin"
if
self
.
quantization
==
"gptq"
:
self
.
quantization
=
hf_
quant_method
self
.
quantization
=
quant_method
if
self
.
quantization
is
None
:
self
.
quantization
=
hf_
quant_method
elif
self
.
quantization
!=
hf_
quant_method
:
self
.
quantization
=
quant_method
elif
self
.
quantization
!=
quant_method
:
raise
ValueError
(
"Quantization method specified in the model config "
f
"(
{
hf_
quant_method
}
) does not match the quantization "
f
"(
{
quant_method
}
) does not match the quantization "
f
"method specified in the `quantization` argument "
f
"(
{
self
.
quantization
}
)."
)
...
...
@@ -200,7 +161,7 @@ class ModelConfig:
f
"Unknown quantization method:
{
self
.
quantization
}
. Must "
f
"be one of
{
supported_quantization
}
."
)
if
is_hip
(
)
and
self
.
quantization
in
rocm_
not_
supported_quantization
:
)
and
self
.
quantization
not
in
rocm_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
...
...
@@ -324,7 +285,7 @@ class CacheConfig:
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
forced_
num_gpu_blocks: Number of GPU blocks to use. This overrides the
num_gpu_blocks
_override
: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
"""
...
...
@@ -334,14 +295,14 @@ class CacheConfig:
gpu_memory_utilization
:
float
,
swap_space
:
int
,
cache_dtype
:
str
,
forced_
num_gpu_blocks
:
Optional
[
int
]
=
None
,
num_gpu_blocks
_override
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
enable_prefix_caching
:
bool
=
False
,
)
->
None
:
self
.
block_size
=
block_size
self
.
gpu_memory_utilization
=
gpu_memory_utilization
self
.
swap_space_bytes
=
swap_space
*
_GB
self
.
forced_
num_gpu_blocks
=
forced_
num_gpu_blocks
self
.
num_gpu_blocks
_override
=
num_gpu_blocks
_override
self
.
cache_dtype
=
cache_dtype
self
.
sliding_window
=
sliding_window
self
.
enable_prefix_caching
=
enable_prefix_caching
...
...
@@ -366,21 +327,20 @@ class CacheConfig:
def
_verify_cache_dtype
(
self
)
->
None
:
if
self
.
cache_dtype
==
"auto"
:
pass
elif
self
.
cache_dtype
==
"fp8_e5m2"
:
if
is_hip
():
raise
NotImplementedError
(
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet."
)
nvcc_cuda_version
=
get_nvcc_cuda_version
()
if
nvcc_cuda_version
and
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
ValueError
(
"FP8 is not supported when cuda version is lower than 11.8."
)
elif
self
.
cache_dtype
==
"fp8"
:
if
not
is_hip
():
nvcc_cuda_version
=
get_nvcc_cuda_version
()
if
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
ValueError
(
"FP8 is not supported when cuda version is"
"lower than 11.8."
)
logger
.
info
(
"Using fp8_e5m2 data type to store kv cache. It reduces "
"the GPU memory footprint and boosts the performance. "
"But it may cause slight accuracy drop. "
"Currently we only support fp8 without scaling factors and "
"make e5m2 as a default format."
)
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"factors. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria."
)
else
:
raise
ValueError
(
f
"Unknown kv cache dtype:
{
self
.
cache_dtype
}
"
)
...
...
@@ -406,7 +366,7 @@ class CacheConfig:
@
dataclass
class
TokenizerPoolConfig
:
"""Configuration for the tokenizer pool.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
...
...
@@ -430,9 +390,9 @@ class TokenizerPoolConfig:
tokenizer_pool_extra_config
:
Optional
[
Union
[
str
,
dict
]]
)
->
Optional
[
"TokenizerPoolConfig"
]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
...
...
@@ -455,6 +415,65 @@ class TokenizerPoolConfig:
return
tokenizer_pool_config
class LoadFormat(str, enum.Enum):
    """Weight-loading formats, usable both as enum members and as strings."""

    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    TENSORIZER = "tensorizer"
@dataclass
class LoadConfig:
    """Configuration for how model weights are located and loaded.

        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
        model_loader_extra_config: Extra loader options, given either as a
            dict or as a JSON string (parsed in ``__post_init__``).
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)

    def __post_init__(self):
        """Parse a JSON-string extra config and validate ``load_format``."""
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(
                model_loader_extra_config)
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        """Normalize a string ``load_format`` into a ``LoadFormat`` member.

        Non-string values are left untouched.

        Raises:
            ValueError: If the string is not a known ``LoadFormat`` value, or
                the format is listed as unsupported on ROCm while running on
                ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare member *values* (lower-case strings) with the
            # lower-case deny list; iterating __members__ would yield the
            # upper-case names, which can never match.
            rocm_supported_load_format = [
                fmt.value for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")
class
ParallelConfig
:
"""Configuration for the distributed execution.
...
...
@@ -530,9 +549,13 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
num_lookahead_slots: The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
accepted.
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
"""
...
...
@@ -543,24 +566,35 @@ class SchedulerConfig:
max_num_seqs
:
int
,
max_model_len
:
int
,
use_v2_block_manager
:
bool
=
False
,
num_lookahead_slots
:
int
=
0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
if
enable_chunked_prefill
:
# For chunked prefill, choose the well-tuned batch size.
self
.
max_num_batched_tokens
=
768
else
:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
if
enable_chunked_prefill
:
logger
.
info
(
"Chunked prefill is enabled (EXPERIMENTAL)."
)
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
delay_factor
=
delay_factor
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
delay_factor
=
delay_factor
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
if
self
.
max_num_batched_tokens
<
self
.
max_model_len
:
if
(
self
.
max_num_batched_tokens
<
self
.
max_model_len
and
not
self
.
chunked_prefill_enabled
):
raise
ValueError
(
f
"max_num_batched_tokens (
{
self
.
max_num_batched_tokens
}
) is "
f
"smaller than max_model_len (
{
self
.
max_model_len
}
). "
...
...
@@ -568,12 +602,19 @@ class SchedulerConfig:
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len."
)
if
self
.
max_num_batched_tokens
<
self
.
max_num_seqs
:
raise
ValueError
(
f
"max_num_batched_tokens (
{
self
.
max_num_batched_tokens
}
) must "
"be greater than or equal to max_num_seqs "
f
"(
{
self
.
max_num_seqs
}
)."
)
if
self
.
num_lookahead_slots
<
0
:
raise
ValueError
(
"num_lookahead_slots "
f
"(
{
self
.
num_lookahead_slots
}
) must be greater than or "
"equal to 0."
)
class
DeviceConfig
:
...
...
@@ -582,6 +623,8 @@ class DeviceConfig:
# Automated device type detection
if
is_neuron
():
self
.
device_type
=
"neuron"
elif
is_cpu
():
self
.
device_type
=
"cpu"
else
:
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
...
...
@@ -598,6 +641,223 @@ class DeviceConfig:
self
.
device
=
torch
.
device
(
self
.
device_type
)
class
SpeculativeConfig
:
"""Configuration for speculative decoding.
The configuration is currently specialized to draft-model speculative
decoding with top-1 proposals.
"""
@staticmethod
def maybe_create_spec_config(
    target_model_config: ModelConfig,
    target_parallel_config: ParallelConfig,
    target_dtype: str,
    speculative_model: Optional[str],
    num_speculative_tokens: Optional[int],
    speculative_max_model_len: Optional[int],
    enable_chunked_prefill: bool,
    use_v2_block_manager: bool,
) -> Optional["SpeculativeConfig"]:
    """Create a SpeculativeConfig if possible, else return None.

    This function attempts to create a SpeculativeConfig object based on the
    provided parameters. If the necessary conditions are met, it returns an
    instance of SpeculativeConfig. Otherwise, it returns None.

    Args:
        target_model_config (ModelConfig): The configuration of the target
            model.
        target_parallel_config (ParallelConfig): The parallel configuration
            for the target model.
        target_dtype (str): The data type used for the target model.
        speculative_model (Optional[str]): The name of the speculative
            model, if provided.
        num_speculative_tokens (Optional[int]): The number of speculative
            tokens, if provided.
        speculative_max_model_len (Optional[int]): The maximum model len of
            the speculative model. Used when testing the ability to skip
            speculation for some sequences.
        enable_chunked_prefill (bool): Whether vLLM is configured to use
            chunked prefill or not. Used for raising an error since it's not
            yet compatible with spec decode.
        use_v2_block_manager (bool): Whether vLLM is configured to use the
            v2 block manager or not. Used for raising an error since the v2
            block manager is required with spec decode.

    Returns:
        Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
            the necessary conditions are met, else None.

    Raises:
        ValueError: If only one of ``speculative_model`` and
            ``num_speculative_tokens`` is given, if chunked prefill is
            enabled, or if the v2 block manager is not enabled.
    """
    if speculative_model is None and num_speculative_tokens is None:
        return None

    # Exactly one of the two was supplied. Report it as a configuration
    # error in both directions instead of tripping a bare assert when only
    # num_speculative_tokens is set.
    if speculative_model is None or num_speculative_tokens is None:
        raise ValueError(
            "Expected both speculative_model and "
            "num_speculative_tokens to be provided, but found "
            f"{speculative_model=} and {num_speculative_tokens=}.")

    if enable_chunked_prefill:
        raise ValueError(
            "Speculative decoding and chunked prefill are "
            f"currently mutually exclusive ({enable_chunked_prefill=}).")

    if not use_v2_block_manager:
        raise ValueError(
            "Speculative decoding requires usage of the V2 "
            "block manager. Enable it with --use-v2-block-manager.")

    # TODO: The user should be able to specify revision/quantization/max
    # model len for the draft model. It is not currently supported.
    draft_revision = None
    draft_code_revision = None
    draft_quantization = None

    draft_model_config = ModelConfig(
        model=speculative_model,
        tokenizer=target_model_config.tokenizer,
        tokenizer_mode=target_model_config.tokenizer_mode,
        trust_remote_code=target_model_config.trust_remote_code,
        dtype=target_model_config.dtype,
        seed=target_model_config.seed,
        revision=draft_revision,
        code_revision=draft_code_revision,
        tokenizer_revision=target_model_config.tokenizer_revision,
        max_model_len=None,
        quantization=draft_quantization,
        enforce_eager=target_model_config.enforce_eager,
        max_context_len_to_capture=target_model_config.
        max_context_len_to_capture,
        max_logprobs=target_model_config.max_logprobs,
    )

    # Clamp the draft model's length so sequences fit both models.
    draft_model_config.max_model_len = (
        SpeculativeConfig._maybe_override_draft_max_model_len(
            speculative_max_model_len,
            draft_model_config.max_model_len,
            target_model_config.max_model_len,
        ))

    draft_parallel_config = (
        SpeculativeConfig.create_draft_parallel_config(
            target_parallel_config))

    return SpeculativeConfig(
        draft_model_config,
        draft_parallel_config,
        num_speculative_tokens,
    )
@staticmethod
def _maybe_override_draft_max_model_len(
    speculative_max_model_len: Optional[int],
    draft_max_model_len: int,
    target_max_model_len: int,
) -> int:
    """Pick the maximum sequence length the draft model may handle.

    Without an explicit override the result is the smaller of the draft
    and target limits, so that sequences do not exceed the capacity of
    either model. When speculative_max_model_len is given it is returned
    unchanged, provided it exceeds neither limit; that override is mainly
    used for testing that sequences can skip speculation.

    Args:
        speculative_max_model_len: Optional explicit override.
        draft_max_model_len: Max sequence length of the draft model.
        target_max_model_len: Max sequence length of the target model.

    Returns:
        The max sequence length to use for the draft model.

    Raises:
        ValueError: If the override is larger than the draft limit or
            larger than the target limit.
    """
    # Common case first: no override, clamp to the tighter of the two.
    if speculative_max_model_len is None:
        return min(draft_max_model_len, target_max_model_len)

    # The draft limit is checked before the target limit.
    if speculative_max_model_len > draft_max_model_len:
        raise ValueError(f"{speculative_max_model_len=} cannot be "
                         f"larger than {draft_max_model_len=}")

    if speculative_max_model_len > target_max_model_len:
        raise ValueError(f"{speculative_max_model_len=} cannot be "
                         f"larger than {target_max_model_len=}")

    return speculative_max_model_len
@staticmethod
def create_draft_parallel_config(
        target_parallel_config: ParallelConfig) -> ParallelConfig:
    """Build the ParallelConfig used by the draft worker.

    For now the draft config mirrors the target config field by field. In
    the future the draft worker can have a different parallel strategy,
    e.g. TP=1.

    Args:
        target_parallel_config: Parallel config of the target model.

    Returns:
        A new ParallelConfig populated from the target's settings.
    """
    target = target_parallel_config
    return ParallelConfig(
        pipeline_parallel_size=target.pipeline_parallel_size,
        tensor_parallel_size=target.tensor_parallel_size,
        worker_use_ray=target.worker_use_ray,
        max_parallel_loading_workers=target.max_parallel_loading_workers,
        disable_custom_all_reduce=target.disable_custom_all_reduce,
        tokenizer_pool_config=target.tokenizer_pool_config,
        ray_workers_use_nsight=target.ray_workers_use_nsight,
        placement_group=target.placement_group,
    )
def __init__(
    self,
    draft_model_config: ModelConfig,
    draft_parallel_config: ParallelConfig,
    num_speculative_tokens: int,
):
    """Store the draft-model settings and validate them.

    Args:
        draft_model_config: ModelConfig for the draft model.
        draft_parallel_config: ParallelConfig for the draft model.
        num_speculative_tokens: The number of tokens to sample from the
            draft model before scoring with the target model.
    """
    self.num_speculative_tokens = num_speculative_tokens
    self.draft_parallel_config = draft_parallel_config
    self.draft_model_config = draft_model_config

    # Validation runs last, once every attribute it reads is set.
    self._verify_args()
def _verify_args(self) -> None:
    """Validate the speculative decoding settings held on this object.

    Raises:
        ValueError: If num_speculative_tokens is not greater than zero.
    """
    token_count = self.num_speculative_tokens
    if token_count <= 0:
        raise ValueError("Expected num_speculative_tokens to be greater "
                         f"than zero ({token_count}).")

    draft_config = self.draft_model_config
    if not draft_config:
        return

    # Let the draft model config check itself against the draft
    # parallel config.
    draft_config.verify_with_parallel_config(self.draft_parallel_config)
@property
def num_lookahead_slots(self) -> int:
    """Extra slots the scheduler should allocate per step.

    These come on top of the slots allocated for each known token. The
    value equals the number of speculative tokens, because every
    speculative token must be scored.
    """
    lookahead_slots = self.num_speculative_tokens
    return lookahead_slots
def __repr__(self) -> str:
    """Return a short description naming the draft model and token count."""
    # Explicit "name=" text with !r yields the same output as the "="
    # f-string specifier, without needing intermediate local variables.
    return ("SpeculativeConfig("
            f"draft_model={self.draft_model_config.model!r}, "
            f"num_spec_tokens={self.num_speculative_tokens!r})")
@
dataclass
class
LoRAConfig
:
max_lora_rank
:
int
...
...
@@ -634,9 +894,12 @@ class LoRAConfig:
self
.
lora_dtype
=
model_config
.
dtype
elif
isinstance
(
self
.
lora_dtype
,
str
):
self
.
lora_dtype
=
getattr
(
torch
,
self
.
lora_dtype
)
if
model_config
.
quantization
is
not
None
:
raise
ValueError
(
"LoRA is not supported with quantized models yet."
)
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
"awq"
,
"gptq"
]:
# TODO support marlin and squeezellm
logger
.
warning
(
f
"
{
model_config
.
quantization
}
quantization is not "
"tested with LoRA yet."
)
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
if
scheduler_config
.
max_num_batched_tokens
>
65528
:
...
...
@@ -802,7 +1065,7 @@ def _get_and_verify_max_len(
derived_max_model_len
*=
scaling_factor
if
max_model_len
is
None
:
max_model_len
=
derived_max_model_len
max_model_len
=
int
(
derived_max_model_len
)
elif
max_model_len
>
derived_max_model_len
:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
...
...
@@ -819,3 +1082,53 @@ def _get_and_verify_max_len(
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size."
)
return
int
(
max_model_len
)
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine."""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Validate the selected guided decoding backend.

        Raises:
            ValueError: If guided_decoding_backend is not one of the
                supported backend names.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # The quote around the offending value is closed and a space
            # separates the two clauses, so the message reads correctly.
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")
@dataclass(frozen=True)
class EngineConfig:
    """Immutable bundle of every engine-related configuration object.

    Grouping the individual configs makes it simpler to pass them around
    the codebase together.
    """
    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]

    def __post_init__(self):
        """Check that the configs are valid and consistent with each other."""
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if not lora:
            return
        # LoRA settings are only checked when a LoRA config is present.
        lora.verify_with_model_config(self.model_config)
        lora.verify_with_scheduler_config(self.scheduler_config)

    def to_dict(self):
        """Return the configs as a name -> value dict, for use in **kwargs."""
        return {spec.name: getattr(self, spec.name) for spec in fields(self)}
vllm/core/block/__init__.py
0 → 100644
View file @
99b471c2
vllm/core/block/block_table.py
View file @
99b471c2
...
...
@@ -85,7 +85,9 @@ class BlockTable:
device
=
device
)
self
.
_num_full_slots
=
len
(
token_ids
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
def
append_token_ids
(
self
,
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
=
0
)
->
None
:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
...
...
@@ -102,14 +104,13 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert
self
.
_is_allocated
assert
self
.
_blocks
is
not
None
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
))
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
first_chunk_size
=
self
.
_block_size
-
(
self
.
_num_full_slots
%
self
.
_block_size
)
token_blocks
=
[
token_ids
[:
first_chunk_size
]]
+
chunk_list
(
token_ids
[
first_chunk_size
:],
self
.
_block_size
)
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
for
block
,
token_block
in
zip
(
blocks
,
token_blocks
):
block
.
append_token_ids
(
token_block
)
...
...
@@ -195,6 +196,25 @@ class BlockTable:
assert
self
.
_is_allocated
return
[
block
.
block_id
for
block
in
self
.
_blocks
]
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
    """Return the "unseen" token ids of the sequence.

    Unseen tokens are tokens in the sequence corresponding to this block
    table that are not yet appended to this block table.

    Args:
        sequence_token_ids (List[int]): The list of token ids in the
            sequence.

    Returns:
        List[int]: The postfix of sequence_token_ids that has not yet been
            appended to the block table.
    """
    # The block table is append-only, so everything past the slots that
    # are already full is still unseen.
    num_seen = self.num_full_slots
    return sequence_token_ids[num_seen:]
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
List
[
Block
]:
...
...
@@ -243,3 +263,29 @@ class BlockTable:
int: The total number of tokens currently stored in the BlockTable.
"""
return
self
.
_num_full_slots
def get_num_blocks_touched_by_append_slots(
        self, token_ids: List[int], num_lookahead_slots: int) -> int:
    """Count the blocks that appending these token ids would "touch".

    The scheduler needs this to decide whether a sequence can continue
    generation or must be preempted.

    Args:
        token_ids: Token ids that would be appended.
        num_lookahead_slots: Number of extra slots to account for after
            the token ids.

    Returns:
        The number of chunks produced by the append chunking.
    """
    # Lookahead slots carry no real tokens; -1 placeholders stand in for
    # them so they are counted by the same chunking as real tokens.
    placeholder_ids = [-1] * num_lookahead_slots
    chunks = self._chunk_token_blocks_for_append(token_ids + placeholder_ids)
    return len(chunks)
def _chunk_token_blocks_for_append(
        self, token_ids: List[int]) -> List[List[int]]:
    """Split token ids into chunks that line up with block boundaries.

    The first chunk may hold fewer token ids than the block size, since
    the last allocated block may be partially full; the remaining token
    ids are split by chunk_list using the block size.
    """
    block_size = self._block_size
    # Slots already used in the block currently being filled.
    used_in_current_block = self._num_full_slots % block_size
    first_chunk_size = block_size - used_in_current_block

    head = token_ids[:first_chunk_size]
    tail = token_ids[first_chunk_size:]
    return [head] + chunk_list(tail, block_size)
vllm/core/block/common.py
View file @
99b471c2
...
...
@@ -99,7 +99,7 @@ class CopyOnWriteTracker:
refcounter
:
RefCounter
,
allocator
:
BlockAllocator
,
):
self
.
_copy_on_writes
=
defaultdict
(
list
)
self
.
_copy_on_writes
:
Dict
[
BlockId
,
List
[
BlockId
]]
=
defaultdict
(
list
)
self
.
_refcounter
=
refcounter
self
.
_allocator
=
allocator
...
...
@@ -138,6 +138,8 @@ class CopyOnWriteTracker:
prev_block
=
block
.
prev_block
).
block_id
# Track src/dst copy.
assert
src_block_id
is
not
None
assert
block_id
is
not
None
self
.
_copy_on_writes
[
src_block_id
].
append
(
block_id
)
return
block_id
...
...
@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
recurse
(
block
.
prev_block
,
lst
)
lst
.
append
(
block
)
all_blocks
=
[]
all_blocks
:
List
[
Block
]
=
[]
recurse
(
last_block
,
all_blocks
)
return
all_blocks
vllm/core/block/interfaces.py
View file @
99b471c2
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
typing
import
Dict
,
List
,
Optional
,
Protocol
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
FrozenSet
,
List
,
Optional
,
Protocol
from
vllm.utils
import
Device
...
...
@@ -10,23 +10,28 @@ class Block(ABC):
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
pass
@
abstractproperty
@
property
@
abstractmethod
def
block_id
(
self
)
->
Optional
[
int
]:
pass
@
abstractproperty
@
property
@
abstractmethod
def
token_ids
(
self
)
->
List
[
int
]:
pass
@
abstractproperty
@
property
@
abstractmethod
def
num_empty_slots
(
self
)
->
int
:
pass
@
abstractproperty
@
property
@
abstractmethod
def
is_full
(
self
)
->
bool
:
pass
@
abstractproperty
@
property
@
abstractmethod
def
prev_block
(
self
)
->
Optional
[
"Block"
]:
pass
...
...
@@ -52,7 +57,7 @@ class BlockAllocator(ABC):
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
])
->
Block
:
token_ids
:
List
[
int
]
,
device
:
Device
)
->
Block
:
pass
@
abstractmethod
...
...
@@ -64,11 +69,12 @@ class BlockAllocator(ABC):
pass
@
abstractmethod
def
get_num_free_blocks
(
self
)
->
int
:
def
get_num_free_blocks
(
self
,
device
:
Device
)
->
int
:
pass
@
abstractproperty
def
all_block_ids
(
self
)
->
frozenset
[
int
]:
@
property
@
abstractmethod
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
pass
@
abstractmethod
...
...
@@ -91,8 +97,7 @@ class BlockAllocator(ABC):
class
DeviceAwareBlockAllocator
(
BlockAllocator
):
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
pass
@
abstractmethod
...
...
vllm/core/block_manager_v1.py
View file @
99b471c2
...
...
@@ -2,7 +2,9 @@
from
abc
import
ABC
,
abstractmethod
from
itertools
import
count
,
takewhile
from
os.path
import
commonprefix
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
...
...
@@ -231,10 +233,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if
self
.
enable_caching
:
logger
.
info
(
"Automatic prefix caching is enabled."
)
self
.
gpu_allocator
=
CachedBlockAllocator
(
Device
.
GPU
,
block_size
,
num_gpu_blocks
)
self
.
cpu_allocator
=
CachedBlockAllocator
(
Device
.
CPU
,
block_size
,
num_cpu_blocks
)
self
.
gpu_allocator
:
BlockAllocatorBase
=
CachedBlockAllocator
(
Device
.
GPU
,
block_size
,
num_gpu_blocks
)
self
.
cpu_allocator
:
BlockAllocatorBase
=
CachedBlockAllocator
(
Device
.
CPU
,
block_size
,
num_cpu_blocks
)
else
:
self
.
gpu_allocator
=
UncachedBlockAllocator
(
Device
.
GPU
,
block_size
,
num_gpu_blocks
)
...
...
@@ -292,7 +294,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
copy
()
def
can_append_slot
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
bool
:
assert
(
num_lookahead_slots
==
0
),
"lookahead allocation not supported in BlockSpaceManagerV1"
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
...
...
@@ -323,7 +330,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self
,
seq
:
Sequence
,
)
->
bool
:
token_ids_len
=
len
(
seq
.
data
.
get_
token_ids
()
)
token_ids_len
=
seq
.
data
.
get_
len
()
return
token_ids_len
>
0
and
token_ids_len
%
seq
.
block_size
==
0
def
_maybe_promote_last_block
(
...
...
@@ -364,10 +371,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
assert
new_block
.
ref_count
==
1
return
new_block
def
append_slot
(
def
append_slot
s
(
self
,
seq
:
Sequence
,
)
->
Optional
[
Tuple
[
int
,
int
]]:
num_lookahead_slots
:
int
=
0
,
)
->
Dict
[
int
,
List
[
int
]]:
"""Allocate a physical slot for a new token."""
logical_blocks
=
seq
.
logical_token_blocks
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
...
...
@@ -386,7 +394,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate a new physical block.
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
block_table
.
append
(
new_block
)
return
None
return
{}
# We want to append the token to the last physical block.
last_block
=
block_table
[
-
1
]
...
...
@@ -399,7 +407,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
maybe_new_block
=
self
.
_maybe_promote_last_block
(
seq
,
last_block
)
block_table
[
-
1
]
=
maybe_new_block
return
None
return
{}
else
:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
...
...
@@ -407,7 +415,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
block_table
[
-
1
]
=
new_block
self
.
gpu_allocator
.
free
(
last_block
)
return
last_block
.
block_number
,
new_block
.
block_number
return
{
last_block
.
block_number
:
[
new_block
.
block_number
]}
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
# NOTE: fork does not allocate a new physical block.
...
...
@@ -433,7 +441,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
blocks
.
update
(
self
.
block_tables
[
seq
.
seq_id
])
return
list
(
blocks
)
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
bool
:
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
num_swapped_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
SWAPPED
)
num_free_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
...
...
@@ -443,7 +455,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_required_blocks
=
len
(
blocks
)
+
num_swapped_seqs
return
num_free_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
def
swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
Dict
[
int
,
int
]:
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
# CPU block -> GPU block.
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
...
...
@@ -573,7 +590,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for
b
in
takewhile
(
lambda
b
:
b
.
computed
,
block_table
[:
-
1
])
]
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
List
[
int
]:
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
GenericSequence
[
int
]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
...
...
vllm/core/block_manager_v2.py
View file @
99b471c2
"""A block manager that manages token blocks."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
...
...
@@ -21,6 +22,24 @@ class BlockSpaceManagerV2(BlockSpaceManager):
sliding-window are not feature complete. This class implements the design
described in https://github.com/vllm-project/vllm/pull/3492.
Lookahead slots
The block manager has the notion of a "lookahead slot". These are slots
in the KV cache that are allocated for a sequence. Unlike the other
allocated slots, the content of these slots is undefined -- the worker
may use the memory allocations in any way.
In practice, a worker could use these lookahead slots to run multiple
forward passes for a single scheduler invocation. Each successive
forward pass would write KV activations to the corresponding lookahead
slot. This allows low inter-token latency use-cases, where the overhead
of continuous batching scheduling is amortized over >1 generated tokens.
Speculative decoding uses lookahead slots to store KV activations of
proposal tokens.
See https://github.com/vllm-project/vllm/pull/3250 for more information
on lookahead scheduling.
Args:
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
...
...
@@ -116,35 +135,51 @@ class BlockSpaceManagerV2(BlockSpaceManager):
for
seq
in
waiting_seqs
[
1
:]:
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
def
can_append_slot
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
"""Determine if there is enough space in the GPU KV cache to continue
generation of the specified sequence group.
We use a worst-case heuristic: assume each touched block will require a
new allocation (either via CoW or new block). We can append slots if the
number of touched blocks is less than the number of free blocks.
"Lookahead slots" are slots that are allocated in addition to the slots
for known tokens. The contents of the lookahead slots are not defined.
This is used by speculative decoding when speculating future tokens.
"""
num_touched_blocks
=
0
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
num_touched_blocks
+=
(
block_table
.
get_num_blocks_touched_by_append_slots
(
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
num_lookahead_slots
=
num_lookahead_slots
,
))
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
Device
.
GPU
)
num_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
RUNNING
)
return
num_seqs
<=
num_free_gpu_blocks
return
num_touched_blocks
<=
num_free_gpu_blocks
def
append_slot
(
def
append_slot
s
(
self
,
seq
:
Sequence
,
)
->
Optional
[
Tuple
[
int
,
int
]]:
num_lookahead_slots
:
int
,
)
->
Dict
[
int
,
List
[
int
]]:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
# Get unseen token ids.
num_full_slots
=
block_table
.
num_full_slots
unseen_token_ids
=
seq
.
get_token_ids
()[
num_full_slots
:]
assert
unseen_token_ids
block_table
.
append_token_ids
(
unseen_token_ids
)
# Return any copy-on-writes.
_
=
self
.
block_allocator
.
clear_copy_on_writes
()
# TODO extend append_slot interface to append_slots
# @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
block_table
.
append_token_ids
(
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
num_lookahead_slots
=
num_lookahead_slots
,
)
return
None
# Return any new copy-on-writes.
new_cows
=
self
.
block_allocator
.
clear_copy_on_writes
()
return
new_cows
def
free
(
self
,
seq
:
Sequence
)
->
None
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
...
...
@@ -171,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# as computed.
self
.
block_allocator
.
mark_blocks_as_computed
()
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
List
[
int
]:
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
GenericSequence
[
int
]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
...
...
@@ -191,10 +227,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
return
False
def
swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
raise
NotImplementedError
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
...
...
vllm/core/interfaces.py
View file @
99b471c2
import
enum
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
from
typing
import
Sequence
as
GenericSequence
from
vllm.sequence
import
Sequence
,
SequenceGroup
...
...
@@ -44,14 +45,16 @@ class BlockSpaceManager(ABC):
pass
@
abstractmethod
def
can_append_slot
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
pass
@
abstractmethod
def
append_slot
(
def
append_slot
s
(
self
,
seq
:
Sequence
,
)
->
Optional
[
Tuple
[
int
,
int
]]:
num_lookahead_slots
:
int
,
)
->
Dict
[
int
,
List
[
int
]]:
pass
@
abstractmethod
...
...
@@ -59,11 +62,13 @@ class BlockSpaceManager(ABC):
pass
@
abstractmethod
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
pass
@
abstractmethod
def
swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
pass
@
abstractmethod
...
...
@@ -99,7 +104,8 @@ class BlockSpaceManager(ABC):
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
List
[
int
]:
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
GenericSequence
[
int
]:
pass
@
abstractmethod
...
...
vllm/core/policy.py
View file @
99b471c2
...
...
@@ -38,9 +38,7 @@ class FCFS(Policy):
class
PolicyFactory
:
_POLICY_REGISTRY
=
{
'fcfs'
:
FCFS
,
}
_POLICY_REGISTRY
=
{
'fcfs'
:
FCFS
}
@
classmethod
def
get_policy
(
cls
,
policy_name
:
str
,
**
kwargs
)
->
Policy
:
...
...
vllm/core/scheduler.py
View file @
99b471c2
import
enum
import
time
from
collections
import
deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.policy
import
PolicyFactory
from
vllm.core.policy
import
Policy
,
PolicyFactory
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.utils
import
merge_dicts
logger
=
init_logger
(
__name__
)
...
...
@@ -28,9 +29,67 @@ class PreemptionMode(enum.Enum):
RECOMPUTE
=
enum
.
auto
()
# seq_group: SequenceGroup to schedule.
# token_chunk_size: The number of prefill tokens to be processed in the next
# step.
@dataclass
class SchedulingBudget:
    """Tracks how many tokens and sequences can still be scheduled.

    TODO(sang): Right now the budget is request_id-aware, meaning it
    ignores a repeated budget update from the same request_id. In the
    normal scheduling path RUNNING num_seqs is updated ahead of time, so
    it could otherwise be updated more than once when scheduling RUNNING
    requests. This won't happen with chunked prefill scheduling only, so
    the feature can be removed from the API once chunked prefill is
    enabled by default.
    """
    token_budget: int
    max_num_seqs: int
    # NOTE: the "_requeset" spelling is kept on purpose; these dataclass
    # fields are also keyword parameters of the generated __init__.
    _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
    _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
        """Return whether the extra tokens and sequences fit the budget."""
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        tokens_fit = (self.num_batched_tokens + num_new_tokens <=
                      self.token_budget)
        seqs_fit = self.num_curr_seqs + num_new_seqs <= self.max_num_seqs
        return tokens_fit and seqs_fit

    def remaining_token_budget(self):
        """Return the token budget that has not been used yet."""
        return self.token_budget - self.num_batched_tokens

    def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
        """Count tokens for req_id, unless that request is already counted."""
        counted = self._requeset_ids_num_batched_tokens
        if req_id not in counted:
            counted.add(req_id)
            self._num_batched_tokens += num_batched_tokens

    def subtract_num_batched_tokens(self, req_id: str,
                                    num_batched_tokens: int):
        """Remove tokens for req_id, only if that request was counted."""
        counted = self._requeset_ids_num_batched_tokens
        if req_id not in counted:
            return
        counted.remove(req_id)
        self._num_batched_tokens -= num_batched_tokens

    def add_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Count sequences for req_id, unless it is already counted."""
        counted = self._requeset_ids_num_curr_seqs
        if req_id not in counted:
            counted.add(req_id)
            self._num_curr_seqs += num_curr_seqs

    def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
        """Remove sequences for req_id, only if that request was counted."""
        counted = self._requeset_ids_num_curr_seqs
        if req_id not in counted:
            return
        counted.remove(req_id)
        self._num_curr_seqs -= num_curr_seqs

    @property
    def num_batched_tokens(self):
        """Number of tokens currently counted against the budget."""
        return self._num_batched_tokens

    @property
    def num_curr_seqs(self):
        """Number of sequences currently counted against the budget."""
        return self._num_curr_seqs
@
dataclass
class
ScheduledSequenceGroup
:
# A sequence group that's scheduled.
...
...
@@ -41,51 +100,29 @@ class ScheduledSequenceGroup:
token_chunk_size
:
int
@
dataclass
class
SchedulerOutputs
:
def
__init__
(
self
,
scheduled_seq_groups
:
Iterable
[
ScheduledSequenceGroup
],
prompt_run
:
bool
,
num_batched_tokens
:
int
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
ignored_seq_groups
:
List
[
SequenceGroup
],
)
->
None
:
"""A list of sequence groups to be scheduled as a single batch.
Args:
scheduled_seq_groups: A tuple of scheduled sequence group and its
token chunk size.
prompt_run: True if all sequence groups are in prefill phase.
If False, all sequence groups are in decoding phase.
num_batched_tokens: Total number of batched tokens.
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
number.
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
number.
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
ignored_seq_groups: Sequence groups that are going to be ignored.
"""
# A tuple of scheduled sequence group and its chunk size.
self
.
scheduled_seq_groups
:
ScheduledSequenceGroup
=
scheduled_seq_groups
# True if all sequence groups are in prefill phase. If False, all
# sequence groups are in decoding phase.
self
.
prompt_run
:
bool
=
prompt_run
# Total number of batched tokens.
self
.
num_batched_tokens
:
int
=
num_batched_tokens
# Blocks to swap in. Dict of CPU -> GPU block number.
self
.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
blocks_to_swap_in
# Blocks to swap out. Dict of GPU -> CPU block number.
self
.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
blocks_to_swap_out
# Blocks to copy. Source to a list of dest blocks.
self
.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
blocks_to_copy
# Sequence groups that are going to be ignored.
self
.
ignored_seq_groups
:
List
[
SequenceGroup
]
=
ignored_seq_groups
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups
:
Iterable
[
ScheduledSequenceGroup
]
# Number of prefill groups scheduled.
num_prefill_groups
:
int
# Total number of batched tokens.
num_batched_tokens
:
int
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in
:
Dict
[
int
,
int
]
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
# Sequence groups that are going to be ignored.
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
def
__post_init__
(
self
):
# Swap in and swap out should never happen at the same time.
assert
not
(
blocks_to_swap_in
and
blocks_to_swap_out
)
assert
not
(
self
.
blocks_to_swap_in
and
self
.
blocks_to_swap_out
)
self
.
num_loras
:
int
=
len
(
self
.
lora_requests
)
if
self
.
num_loras
>
0
:
...
...
@@ -96,14 +133,106 @@ class SchedulerOutputs:
return
(
not
self
.
scheduled_seq_groups
and
not
self
.
blocks_to_swap_in
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
def
_sort_by_lora_ids
(
self
)
->
bool
:
def
_sort_by_lora_ids
(
self
):
self
.
scheduled_seq_groups
=
sorted
(
self
.
scheduled_seq_groups
,
key
=
lambda
g
:
(
g
.
seq_group
.
lora_int_id
,
g
.
seq_group
.
request_id
))
@
property
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
}
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
if
g
.
seq_group
.
lora_request
is
not
None
}
@dataclass
class SchedulerRunningOutputs:
    """The requests that are scheduled from a running queue.

    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """
    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[SequenceGroup]
    # Selected sequences that are running and in a prefill phase.
    # I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[SequenceGroup]
    # The preempted sequences.
    preempted: List[SequenceGroup]
    # Sequences that are swapped out.
    swapped_out: List[SequenceGroup]
    # The blocks to swap out.
    blocks_to_swap_out: Dict[int, int]
    # The blocks to copy.
    blocks_to_copy: Dict[int, List[int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerRunningOutputs":
        """Return an instance with empty collections and zero lookahead slots.

        Fresh list and dict objects are created on every call, so the
        returned instances never share mutable state.
        """
        # Instantiate through cls so that a subclass gets an instance of
        # itself rather than of the base class.
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            preempted=[],
            swapped_out=[],
            blocks_to_swap_out={},
            blocks_to_copy={},
            num_lookahead_slots=0,
        )
@dataclass
class SchedulerSwappedInOutputs:
    """The requests that are scheduled from a swap queue.

    Could contain prefill (prefill that's chunked) or decodes.
    """
    # Selected sequences that are going to be swapped in and are in a
    # decoding phase. Each entry pairs the group with its token chunk size.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are going to be swapped in and are in a prefill
    # phase. I.e., it means the prefill has been chunked.
    prefill_seq_groups: List[ScheduledSequenceGroup]
    # The blocks to swap in.
    blocks_to_swap_in: Dict[int, int]
    # The blocks to copy.
    blocks_to_copy: Dict[int, List[int]]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerSwappedInOutputs":
        """Return an instance with empty containers and zero lookahead slots."""
        # Build through ``cls`` so a subclass gets an instance of itself.
        return cls(
            decode_seq_groups=[],
            prefill_seq_groups=[],
            blocks_to_swap_in={},
            blocks_to_copy={},
            num_lookahead_slots=0,
        )
@dataclass
class SchedulerPrefillOutputs:
    """The requests that are scheduled from a waiting queue.

    Could contain fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """
    # Selected sequences for prefill. Each entry pairs the sequence group
    # with the number of prompt tokens scheduled for it.
    seq_groups: List[ScheduledSequenceGroup]
    # Ignored sequence groups.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int

    @classmethod
    def create_empty(cls) -> "SchedulerPrefillOutputs":
        """Return an instance with empty lists and zero lookahead slots."""
        # Build through ``cls`` so a subclass gets an instance of itself.
        return cls(
            seq_groups=[],
            ignored_seq_groups=[],
            num_lookahead_slots=0,
        )
class
Scheduler
:
...
...
@@ -121,11 +250,12 @@ class Scheduler:
# LoRAs. This should be improved in the future.
self
.
lora_config
=
lora_config
self
.
prompt_limit
=
min
(
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_num_batched_tokens
)
# Instantiate the scheduling policy.
self
.
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
if
self
.
scheduler_config
.
chunked_prefill_enabled
:
self
.
prompt_limit
=
self
.
scheduler_config
.
max_model_len
else
:
self
.
prompt_limit
=
min
(
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_num_batched_tokens
)
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
version
=
"v2"
if
self
.
scheduler_config
.
...
...
@@ -140,10 +270,13 @@ class Scheduler:
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
)
# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
self
.
waiting
:
Deque
[
SequenceGroup
]
=
deque
()
# Sequence groups in the RUNNING state.
# Contain decode requests.
self
.
running
:
Deque
[
SequenceGroup
]
=
deque
()
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self
.
swapped
:
Deque
[
SequenceGroup
]
=
deque
()
# Time at previous scheduling step
...
...
@@ -157,6 +290,11 @@ class Scheduler:
def lora_enabled(self) -> bool:
    """Report whether ``self.lora_config`` is set to a truthy value."""
    return True if self.lora_config else False
@property
def num_decoding_tokens_per_seq(self) -> int:
    """Number of new tokens per sequence in a decode step (always 1)."""
    return 1
def
add_seq_group
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# Add sequence groups to the waiting queue.
self
.
waiting
.
append
(
seq_group
)
...
...
@@ -198,215 +336,552 @@ class Scheduler:
self
.
free_seq
(
seq
)
def has_unfinished_seqs(self) -> bool:
    """Return True while any of the waiting/running/swapped queues is non-empty.

    The explicit length comparisons guarantee a real ``bool`` is returned
    rather than one of the queue objects.
    """
    return (len(self.waiting) != 0 or len(self.running) != 0
            or len(self.swapped) != 0)
def get_num_unfinished_seq_groups(self) -> int:
    """Count the sequence groups held in the waiting, running and swapped queues."""
    queues = (self.waiting, self.running, self.swapped)
    return sum(len(queue) for queue in queues)
def _schedule_running(
    self,
    running_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    policy: Policy,
    enable_chunking: bool = False,
) -> Tuple[deque, SchedulerRunningOutputs]:
    """Schedule sequence groups that are running.

    Running queue should include decode and chunked prefill requests.

    Args:
        running_queue: The queue that contains running requests (i.e.,
            decodes). The given arguments are NOT in-place modified.
        budget: The scheduling budget. The argument is in-place updated
            when groups are scheduled or preempted.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when groups are scheduled or preempted.
        policy: The sorting policy to sort running_queue.
        enable_chunking: If True, seq group can be chunked and only a
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

    Returns:
        A tuple of remaining running queue (should be always 0) after
        scheduling and SchedulerRunningOutputs.
    """
    # Blocks that need to be swapped or copied before model execution.
    blocks_to_swap_out: Dict[int, int] = {}
    blocks_to_copy: Dict[int, List[int]] = {}

    decode_seq_groups: List[ScheduledSequenceGroup] = []
    prefill_seq_groups: List[ScheduledSequenceGroup] = []
    preempted: List[SequenceGroup] = []
    swapped_out: List[SequenceGroup] = []

    # NOTE(woosuk): Preemption happens only when there is no available slot
    # to keep all the sequence groups in the RUNNING state.
    # In this case, the policy is responsible for deciding which sequence
    # groups to preempt.
    now = time.time()
    running_queue = policy.sort_by_priority(now, running_queue)

    while running_queue:
        seq_group = running_queue[0]
        num_running_tokens = self._get_num_new_tokens(
            seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

        # We can have up to 1 running prefill at any given time in running
        # queue, which means we can guarantee chunk size is at least 1.
        assert num_running_tokens != 0

        running_queue.popleft()
        while not self._can_append_slots(seq_group):
            budget.subtract_num_batched_tokens(seq_group.request_id,
                                               num_running_tokens)
            num_running_seqs = seq_group.get_max_num_running_seqs()
            budget.subtract_num_seqs(seq_group.request_id,
                                     num_running_seqs)
            if curr_loras is not None and seq_group.lora_int_id > 0:
                # discard() rather than remove(): the id may never have
                # been added (e.g. when the caller starts from an empty
                # set), and remove() would raise KeyError in that case.
                curr_loras.discard(seq_group.lora_int_id)

            if running_queue:
                # Preempt the lowest-priority sequence groups.
                victim_seq_group = running_queue.pop()
                preempted_mode = self._preempt(victim_seq_group,
                                               blocks_to_swap_out)
                if preempted_mode == PreemptionMode.RECOMPUTE:
                    preempted.append(victim_seq_group)
                else:
                    swapped_out.append(victim_seq_group)
            else:
                # No other sequence groups can be preempted.
                # Preempt the current sequence group.
                preempted_mode = self._preempt(seq_group,
                                               blocks_to_swap_out)
                if preempted_mode == PreemptionMode.RECOMPUTE:
                    preempted.append(seq_group)
                else:
                    swapped_out.append(seq_group)
                break
        else:
            # Reached only when the inner loop ended without `break`,
            # i.e. there is room to append slots for this group.
            self._append_slots(seq_group, blocks_to_copy)
            is_prefill = seq_group.is_prefill()
            if is_prefill:
                prefill_seq_groups.append(
                    ScheduledSequenceGroup(
                        seq_group=seq_group,
                        token_chunk_size=num_running_tokens))
            else:
                decode_seq_groups.append(
                    ScheduledSequenceGroup(seq_group=seq_group,
                                           token_chunk_size=1))
            budget.add_num_batched_tokens(seq_group.request_id,
                                          num_running_tokens)
            # OPTIMIZATION: Note that get_max_num_running_seqs is
            # expensive. For the default scheduling case where
            # enable_chunking is False, num_seqs are updated before running
            # this method, so we don't have to update it again here.
            if enable_chunking:
                num_running_seqs = seq_group.get_max_num_running_seqs()
                budget.add_num_seqs(seq_group.request_id, num_running_seqs)
            if curr_loras is not None and seq_group.lora_int_id > 0:
                curr_loras.add(seq_group.lora_int_id)

    # Make sure all queues are updated.
    assert len(running_queue) == 0

    return running_queue, SchedulerRunningOutputs(
        decode_seq_groups=decode_seq_groups,
        prefill_seq_groups=prefill_seq_groups,
        preempted=preempted,
        swapped_out=swapped_out,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
        num_lookahead_slots=self._get_num_lookahead_slots(
            is_prefill=False))
def _schedule_swapped(
    self,
    swapped_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    policy: Policy,
    enable_chunking: bool = False,
) -> Tuple[deque, SchedulerSwappedInOutputs]:
    """Schedule sequence groups that are swapped out.

    It schedules swapped requests as long as it fits `budget` and
    curr_loras <= max_lora from the scheduling config. The input arguments
    `budget` and `curr_loras` are updated based on scheduled seq_groups.

    Args:
        swapped_queue: The queue that contains swapped out requests.
            The given arguments are NOT in-place modified.
        budget: The scheduling budget. The argument is in-place updated
            when any requests are swapped in.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when any requests are swapped in.
        policy: The sorting policy to sort swapped_queue.
        enable_chunking: If True, seq group can be chunked and only a
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

    Returns:
        A tuple of remaining swapped_queue after scheduling and
        SchedulerSwappedInOutputs.
    """
    # Blocks that need to be swapped or copied before model execution.
    blocks_to_swap_in: Dict[int, int] = {}
    blocks_to_copy: Dict[int, List[int]] = {}
    decode_seq_groups: List[ScheduledSequenceGroup] = []
    prefill_seq_groups: List[ScheduledSequenceGroup] = []
    now = time.time()
    swapped_queue = policy.sort_by_priority(now, swapped_queue)

    # Groups skipped only because no LoRA slot is free; they are put back
    # at the front of the queue once the loop finishes.
    leftover_swapped: Deque[SequenceGroup] = deque()
    while swapped_queue:
        seq_group = swapped_queue[0]

        # If the sequence group cannot be swapped in, stop.
        if not self.block_manager.can_swap_in(seq_group):
            break

        lora_int_id = 0
        if self.lora_enabled:
            lora_int_id = seq_group.lora_int_id
            assert curr_loras is not None
            assert self.lora_config is not None
            if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                    and len(curr_loras) >= self.lora_config.max_loras):
                # We don't have a space for another LoRA, so
                # we ignore this request for now.
                leftover_swapped.appendleft(seq_group)
                swapped_queue.popleft()
                continue

        # The total number of sequences in the RUNNING state should not
        # exceed the maximum number of sequences.
        num_new_seqs = seq_group.get_max_num_running_seqs()
        num_new_tokens = self._get_num_new_tokens(seq_group,
                                                  SequenceStatus.SWAPPED,
                                                  enable_chunking, budget)

        if (num_new_tokens == 0
                or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                           num_new_seqs=num_new_seqs)):
            break

        if lora_int_id > 0 and curr_loras is not None:
            curr_loras.add(lora_int_id)
        swapped_queue.popleft()
        self._swap_in(seq_group, blocks_to_swap_in)
        self._append_slots(seq_group, blocks_to_copy)
        is_prefill = seq_group.is_prefill()
        if is_prefill:
            prefill_seq_groups.append(
                ScheduledSequenceGroup(seq_group,
                                       token_chunk_size=num_new_tokens))
        else:
            assert num_new_tokens == 1
            decode_seq_groups.append(
                ScheduledSequenceGroup(seq_group, token_chunk_size=1))
        budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
        budget.add_num_seqs(seq_group.request_id, num_new_seqs)

    # appendleft() above plus extendleft() here restores the original
    # relative order of the skipped groups.
    swapped_queue.extendleft(leftover_swapped)

    return swapped_queue, SchedulerSwappedInOutputs(
        decode_seq_groups=decode_seq_groups,
        prefill_seq_groups=prefill_seq_groups,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_copy=blocks_to_copy,
        num_lookahead_slots=self._get_num_lookahead_slots(
            is_prefill=False))
def _schedule_prefills(
    self,
    waiting_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]:
    """Schedule sequence groups that are in prefill stage.

    Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
    as a new prefill (that starts from beginning -> most recently generated
    tokens).

    It schedules waiting requests as long as it fits `budget` and
    curr_loras <= max_lora from the scheduling config. The input arguments
    `budget` and `curr_loras` are updated based on scheduled seq_groups.

    Args:
        waiting_queue: The queue that contains prefill requests.
            The given arguments are NOT in-place modified.
        budget: The scheduling budget. The argument is in-place updated
            when any requests are scheduled.
        curr_loras: Currently batched lora request ids. The argument is
            in-place updated when any requests are scheduled.
        enable_chunking: If True, seq group can be chunked and only a
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

    Returns:
        A tuple of remaining waiting_queue after scheduling and
        SchedulerPrefillOutputs.
    """
    ignored_seq_groups: List[SequenceGroup] = []
    seq_groups: List[ScheduledSequenceGroup] = []
    # We don't sort waiting queue because we assume it is sorted.
    # Copy the queue so that the input queue is not modified.
    waiting_queue = deque(waiting_queue)

    # Groups skipped only because no LoRA slot is free; they are put back
    # at the front of the queue once the loop finishes.
    leftover_waiting_sequences: Deque[SequenceGroup] = deque()
    while self._passed_delay(time.time()) and waiting_queue:
        seq_group = waiting_queue[0]

        waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
        assert len(waiting_seqs) == 1, (
            "Waiting sequence group should have only one prompt "
            "sequence.")
        num_new_tokens = self._get_num_new_tokens(seq_group,
                                                  SequenceStatus.WAITING,
                                                  enable_chunking, budget)
        if not enable_chunking:
            # Without chunking the whole prompt is scheduled at once.
            num_prompt_tokens = waiting_seqs[0].get_len()
            assert num_new_tokens == num_prompt_tokens

        if num_new_tokens > self.prompt_limit:
            logger.warning(
                f"Input prompt ({num_new_tokens} tokens) is too long"
                f" and exceeds limit of {self.prompt_limit}")
            for seq in waiting_seqs:
                seq.status = SequenceStatus.FINISHED_IGNORED
            ignored_seq_groups.append(seq_group)
            waiting_queue.popleft()
            continue

        # If the sequence group cannot be allocated, stop.
        can_allocate = self.block_manager.can_allocate(seq_group)
        if can_allocate == AllocStatus.LATER:
            break
        elif can_allocate == AllocStatus.NEVER:
            logger.warning(
                f"Input prompt ({num_new_tokens} tokens) is too long"
                f" and exceeds the capacity of block_manager")
            for seq in waiting_seqs:
                seq.status = SequenceStatus.FINISHED_IGNORED
            ignored_seq_groups.append(seq_group)
            waiting_queue.popleft()
            continue

        lora_int_id = 0
        if self.lora_enabled:
            lora_int_id = seq_group.lora_int_id
            assert curr_loras is not None
            assert self.lora_config is not None
            if (lora_int_id > 0 and lora_int_id not in curr_loras
                    and len(curr_loras) >= self.lora_config.max_loras):
                # We don't have a space for another LoRA, so
                # we ignore this request for now.
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue

        num_new_seqs = seq_group.get_max_num_running_seqs()
        if (num_new_tokens == 0
                or not budget.can_schedule(num_new_tokens=num_new_tokens,
                                           num_new_seqs=num_new_seqs)):
            break

        # Can schedule this request.
        if curr_loras is not None and lora_int_id > 0:
            curr_loras.add(lora_int_id)
        waiting_queue.popleft()
        self._allocate_and_set_running(seq_group)
        seq_groups.append(
            ScheduledSequenceGroup(seq_group=seq_group,
                                   token_chunk_size=num_new_tokens))
        budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
        budget.add_num_seqs(seq_group.request_id, num_new_seqs)

    # Queue requests that couldn't be scheduled.
    waiting_queue.extendleft(leftover_waiting_sequences)
    if seq_groups:
        self.prev_prompt = True

    return waiting_queue, SchedulerPrefillOutputs(
        seq_groups=seq_groups,
        ignored_seq_groups=ignored_seq_groups,
        num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
def _schedule_default(self) -> SchedulerOutputs:
    """Schedule queued requests.

    The current policy is designed to optimize the throughput. First,
    it batches as many prefill requests as possible. And it schedules
    decodes. If there's a pressure on GPU memory, decode requests can
    be swapped or preempted.
    """
    # Include running requests to the budget.
    budget = SchedulingBudget(
        token_budget=self.scheduler_config.max_num_batched_tokens,
        max_num_seqs=self.scheduler_config.max_num_seqs,
    )
    # Make sure we include num running seqs before scheduling prefill,
    # so that we don't schedule beyond max_num_seqs for prefill.
    for seq_group in self.running:
        budget.add_num_seqs(seq_group.request_id,
                            seq_group.get_max_num_running_seqs())
    # Only ids > 0 denote a real LoRA. Keeping id 0 (requests without a
    # LoRA) in the set would inflate len(curr_loras) and wrongly consume
    # one of the max_loras slots.
    curr_loras = set(
        seq_group.lora_int_id for seq_group in self.running
        if seq_group.lora_int_id > 0) if self.lora_enabled else None

    remaining_waiting, prefills = (self.waiting,
                                   SchedulerPrefillOutputs.create_empty())
    remaining_running, running_scheduled = (
        self.running, SchedulerRunningOutputs.create_empty())
    remaining_swapped, swapped_in = (
        self.swapped, SchedulerSwappedInOutputs.create_empty())

    # If any requests are swapped, prioritize swapped requests.
    if not self.swapped:
        remaining_waiting, prefills = self._schedule_prefills(
            self.waiting, budget, curr_loras, enable_chunking=False)

    fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
    # Don't schedule decodes if prefills are scheduled.
    # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
    # only contains decode requests, not chunked prefills.
    if len(prefills.seq_groups) == 0:
        remaining_running, running_scheduled = self._schedule_running(
            self.running,
            budget,
            curr_loras,
            fcfs_policy,
            enable_chunking=False)

        # If any sequence group is preempted, do not swap in any sequence
        # group, because it means there's no slot for new running requests.
        if len(running_scheduled.preempted) + len(
                running_scheduled.swapped_out) == 0:
            remaining_swapped, swapped_in = self._schedule_swapped(
                self.swapped, budget, curr_loras, fcfs_policy)

    assert (budget.num_batched_tokens <=
            self.scheduler_config.max_num_batched_tokens)
    assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

    # Update waiting requests.
    self.waiting = remaining_waiting
    self.waiting.extendleft(running_scheduled.preempted)
    # Update new running requests.
    self.running = remaining_running
    self.running.extend([s.seq_group for s in prefills.seq_groups])
    self.running.extend(
        [s.seq_group for s in running_scheduled.decode_seq_groups])
    self.running.extend(
        [s.seq_group for s in swapped_in.decode_seq_groups])
    # Update swapped requests.
    self.swapped = remaining_swapped
    self.swapped.extend(running_scheduled.swapped_out)

    # There should be no prefill from running queue because this policy
    # doesn't allow chunked prefills.
    assert len(running_scheduled.prefill_seq_groups) == 0
    assert len(swapped_in.prefill_seq_groups) == 0
    return SchedulerOutputs(
        scheduled_seq_groups=(prefills.seq_groups +
                              running_scheduled.decode_seq_groups +
                              swapped_in.decode_seq_groups),
        num_prefill_groups=len(prefills.seq_groups),
        num_batched_tokens=budget.num_batched_tokens,
        blocks_to_swap_in=swapped_in.blocks_to_swap_in,
        blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
        blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                   swapped_in.blocks_to_copy),
        ignored_seq_groups=prefills.ignored_seq_groups,
        num_lookahead_slots=running_scheduled.num_lookahead_slots,
    )
def _schedule_chunked_prefill(self) -> SchedulerOutputs:
    """Schedule queued requests.

    Chunked prefill allows prefill requests to be chunked and batched
    together with decode requests. As written here, the policy:
    1. schedules the running queue (decodes and unfinished chunked
       prefills),
    2. schedules swapped requests, unless step 1 preempted or swapped out
       anything,
    3. schedules new prefill requests.

    The policy can sustain high GPU utilization because it can put
    prefill and decode requests in the same batch, while it improves
    inter-token latency because decode requests don't need to be blocked
    by prefill requests.
    """
    budget = SchedulingBudget(
        token_budget=self.scheduler_config.max_num_batched_tokens,
        max_num_seqs=self.scheduler_config.max_num_seqs,
    )
    curr_loras: Set[int] = set()

    remaining_waiting, prefills = (self.waiting,
                                   SchedulerPrefillOutputs.create_empty())
    remaining_running, running_scheduled = (
        self.running, SchedulerRunningOutputs.create_empty())
    remaining_swapped, swapped_in = (
        self.swapped, SchedulerSwappedInOutputs.create_empty())

    # Decoding should be always scheduled first by fcfs.
    fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
    remaining_running, running_scheduled = self._schedule_running(
        self.running,
        budget,
        curr_loras,
        fcfs_policy,
        enable_chunking=True)

    # Schedule swapped out requests.
    # If preemption happens, it means we don't have space for swap-in.
    if len(running_scheduled.preempted) + len(
            running_scheduled.swapped_out) == 0:
        remaining_swapped, swapped_in = self._schedule_swapped(
            self.swapped, budget, curr_loras, fcfs_policy)

    # Schedule new prefills.
    remaining_waiting, prefills = self._schedule_prefills(
        self.waiting, budget, curr_loras, enable_chunking=True)

    assert (budget.num_batched_tokens <=
            self.scheduler_config.max_num_batched_tokens)
    assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

    # Update waiting requests.
    self.waiting = remaining_waiting
    self.waiting.extendleft(running_scheduled.preempted)
    # Update new running requests.
    self.running = remaining_running
    self.running.extend([s.seq_group for s in prefills.seq_groups])
    self.running.extend(
        [s.seq_group for s in running_scheduled.decode_seq_groups])
    self.running.extend(
        [s.seq_group for s in running_scheduled.prefill_seq_groups])
    self.running.extend(
        [s.seq_group for s in swapped_in.decode_seq_groups])
    self.running.extend(
        [s.seq_group for s in swapped_in.prefill_seq_groups])
    # Update swapped requests.
    self.swapped = remaining_swapped
    self.swapped.extend(running_scheduled.swapped_out)

    # Prefill groups are listed before decode groups in the output.
    return SchedulerOutputs(
        scheduled_seq_groups=(prefills.seq_groups +
                              running_scheduled.prefill_seq_groups +
                              swapped_in.prefill_seq_groups +
                              running_scheduled.decode_seq_groups +
                              swapped_in.decode_seq_groups),
        num_prefill_groups=(len(prefills.seq_groups) +
                            len(swapped_in.prefill_seq_groups) +
                            len(running_scheduled.prefill_seq_groups)),
        num_batched_tokens=budget.num_batched_tokens,
        blocks_to_swap_in=swapped_in.blocks_to_swap_in,
        blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
        blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                   swapped_in.blocks_to_copy),
        ignored_seq_groups=prefills.ignored_seq_groups,
        num_lookahead_slots=running_scheduled.num_lookahead_slots,
    )
def _schedule(self) -> SchedulerOutputs:
    """Dispatch to the chunked-prefill or the default scheduling policy."""
    use_chunked_prefill = self.scheduler_config.chunked_prefill_enabled
    if not use_chunked_prefill:
        return self._schedule_default()
    return self._schedule_chunked_prefill()
def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
    """Ask the block manager whether slots can be appended for `seq_group`.

    This determines whether there is enough space in the KV cache to
    continue generation of the sequence group.
    """
    # Appending slots only occurs in decoding, so the lookahead slot count
    # is always requested for the non-prefill case.
    lookahead_slots = self._get_num_lookahead_slots(False)
    return self.block_manager.can_append_slots(
        seq_group=seq_group,
        num_lookahead_slots=lookahead_slots,
    )
def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
    """Ask the block manager whether `seq_group` can be swapped back in.

    The lookahead slot count passed along is the one for the non-prefill
    (decode) case.
    """
    # Swapping in is considered decode.
    is_prefill = False

    return self.block_manager.can_swap_in(
        seq_group=seq_group,
        num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
    )
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
# Schedule sequence groups.
...
...
@@ -417,7 +892,8 @@ class Scheduler:
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
i
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
...
...
@@ -437,9 +913,12 @@ class Scheduler:
self
.
block_manager
.
get_common_computed_block_ids
(
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt
=
seq_group
.
is_prefill
()
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
seq_group
.
request_id
,
is_prompt
=
scheduler_outputs
.
prompt
_run
,
is_prompt
=
is_
prompt
,
seq_data
=
seq_data
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
...
...
@@ -452,7 +931,7 @@ class Scheduler:
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
prompt_run
else
None
,
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
...
...
@@ -477,31 +956,42 @@ class Scheduler:
self
.
running
=
deque
(
seq_group
for
seq_group
in
self
.
running
if
not
seq_group
.
is_finished
())
def
_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_allocate
_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
self
.
block_manager
.
allocate
(
seq_group
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
seq
.
status
=
SequenceStatus
.
RUNNING
def
_append_slot
(
def
_append_slot
s
(
self
,
seq_group
:
SequenceGroup
,
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
"""Appends new slots to the sequences in the given sequence group.
Args:
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
block indices to lists of destination block indices. This
dictionary is updated with the new source and destination block
indices for the appended slots.
"""
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
ret
=
self
.
block_manager
.
append_slot
(
seq
)
if
ret
is
not
None
:
src_block
,
dst_block
=
ret
if
src_block
in
blocks_to_copy
:
blocks_to_copy
[
src_block
].
append
(
dst_block
)
else
:
blocks_to_copy
[
src_block
]
=
[
dst_block
]
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
for
src
,
dests
in
cows
.
items
():
if
src
not
in
blocks_to_copy
:
blocks_to_copy
[
src
]
=
[]
blocks_to_copy
[
src
].
extend
(
dests
)
def
_preempt
(
self
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
)
->
Non
e
:
)
->
PreemptionMod
e
:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
...
...
@@ -524,6 +1014,7 @@ class Scheduler:
self
.
_preempt_by_swap
(
seq_group
,
blocks_to_swap_out
)
else
:
raise
AssertionError
(
"Invalid preemption mode."
)
return
preemption_mode
def
_preempt_by_recompute
(
self
,
...
...
@@ -535,9 +1026,6 @@ class Scheduler:
seq
.
status
=
SequenceStatus
.
WAITING
self
.
free_seq
(
seq
)
seq
.
reset_state_for_recompute
()
# NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue.
self
.
waiting
.
appendleft
(
seq_group
)
def
_preempt_by_swap
(
self
,
...
...
@@ -545,7 +1033,6 @@ class Scheduler:
blocks_to_swap_out
:
Dict
[
int
,
int
],
)
->
None
:
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
self
.
swapped
.
append
(
seq_group
)
def
_swap_in
(
self
,
...
...
@@ -588,3 +1075,39 @@ class Scheduler:
else
:
passed_delay
=
True
return
passed_delay
def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
    """Return how many extra slots to allocate per sequence per step.

    These slots lie beyond the known token ids; speculative decoding uses
    them to hold KV activations of tokens that may or may not be accepted.

    Args:
        is_prefill: True when the step being scheduled is a prefill step.

    Returns:
        0 for prefill (speculative decoding does not yet support prefill,
        so no lookahead allocation is done), otherwise
        ``self.scheduler_config.num_lookahead_slots``.
    """
    # The config is only consulted on the decode path.
    return 0 if is_prefill else self.scheduler_config.num_lookahead_slots
def _get_num_new_tokens(self, seq_group: SequenceGroup,
                        status: SequenceStatus, enable_chunking: bool,
                        budget: SchedulingBudget) -> int:
    """Count the new tokens to compute for ``seq_group`` in ``status``.

    Args:
        seq_group: The sequence group being scheduled.
        status: Only sequences of the group in this status are counted.
        enable_chunking: When True, the count may be capped by ``budget``.
        budget: Supplies ``remaining_token_budget()`` used as the cap.

    Returns:
        The sum of ``get_num_new_tokens()`` over the matching sequences.
        If chunking is enabled and exactly one sequence matches, the sum
        is capped at the remaining token budget. A group with several
        sequences (e.g. beam search) is in the decode phase and is never
        chunked.
    """
    matching_seqs = seq_group.get_seqs(status=status)
    total_new_tokens = sum(
        seq.get_num_new_tokens() for seq in matching_seqs)

    # Only a single-sequence group may be chunked; more than one sequence
    # means beam search in the decode phase, which is left untouched.
    if not enable_chunking or len(matching_seqs) != 1:
        return total_new_tokens

    # Chunk when the request cannot fit in what is left of the budget.
    return min(total_new_tokens, budget.remaining_token_budget())
Prev
1
…
6
7
8
9
10
11
12
13
14
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment