Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1b78ef29
Commit
1b78ef29
authored
Jul 31, 2025
by
zhuwenwen
Browse files
remove unused code
parent
0628e4b4
Changes
7
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
3 additions
and
2252 deletions
+3
-2252
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+0
-469
vllm/attention/backends/cpu_mla.py
vllm/attention/backends/cpu_mla.py
+0
-307
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+0
-403
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+0
-356
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+0
-707
vllm/attention/backends/tree_decoding_utils.py
vllm/attention/backends/tree_decoding_utils.py
+1
-2
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+2
-8
No files found.
vllm/attention/backends/blocksparse_attn.py
deleted
100644 → 0
View file @
0628e4b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
@dataclass
class BlocksparseParams:
    """Configuration of the local + strided block-sparse attention pattern.

    `head_sliding_step` and `active_head_range` are not constructor
    arguments; they are derived in `__post_init__` from the other fields
    and from the tensor-parallel world size / rank.
    """

    max_seqlen: int

    # Number of query heads on this tensor-parallel rank/partition.
    num_heads: int
    # Number of kv heads on this tensor-parallel rank/partition.
    num_kv_heads: int

    # Block size used for blocksparse attention.
    # This is the block_size that `local_blocks` and `vert_stride` refer to.
    block_size: int

    # Number of blocks for local attention, i.e., number of
    # local attended tokens / `sparse_block_size`.
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks.
    # Controls the sparsity.
    vert_stride: int

    # Whether to use the same vertical stride offset for all heads,
    # i.e., attend to the same block of tokens on all heads.
    # By default it is False, i.e., attention on the non-local
    # blocks depends on the `head_idx`, that is on blocks satisfying
    # `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    # where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
    # `block_idx = position_id // sparse_block_size`.
    # See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    # for more detail.
    homo_head: bool = False

    # Whether the kv offsets attended by each q within a group are the same.
    homo_head_group: bool = False

    # Derived from `homo_head` and `homo_head_group`.
    head_sliding_step: int = field(init=False)

    # Range of q heads handled by this TP rank.
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        """Validate the block parameters and fill in the derived fields."""
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1

        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()
        # Head counts across all tensor-parallel partitions.
        all_q_heads = world_size * self.num_heads
        all_kv_heads = world_size * self.num_kv_heads

        if self.homo_head:
            step = 0
        elif self.homo_head_group:
            # Negative indicates sliding along kv heads, i.e., homo q group.
            step = -get_head_sliding_step(all_kv_heads, self.vert_stride)
        else:
            step = get_head_sliding_step(all_q_heads, self.vert_stride)
        self.head_sliding_step = step

        first_head = rank * self.num_heads
        end_head = (rank + 1) * self.num_heads
        self.active_head_range = (first_head, end_head)
class BlocksparseFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for block-sparse flash attention.

    Names the implementation, metadata, builder and state classes, and
    forwards the KV-cache helpers to `PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "BLOCK_SPARSE_FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class used by this backend."""
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by `PagedAttention`."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward a block swap request to `PagedAttention.swap_blocks`."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward a block copy request to `PagedAttention.copy_blocks`."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    """A copy of Metadata for FlashAttentionBackend,
    to avoid having to install flash_attn.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built views returned by the two properties below.
    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None

    # Passed through unchanged to both the prefill and the decode view.
    block_tables_list: Optional[List[int]] = None

    @property
    def prefill_metadata(
            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return the metadata restricted to the prefill requests.

        Returns None when the batch has no prefills. The view is built on
        first access and cached on this instance afterwards.
        """
        n_prefills = self.num_prefills
        if n_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        n_prefill_tokens = self.num_prefill_tokens
        # Prefill entries come first, so every per-sequence field is
        # truncated to the leading `n_prefills` rows (cumulative *_start_loc
        # tensors keep one extra element).
        view = BlocksparseFlashAttentionMetadata(
            num_prefills=n_prefills,
            num_prefill_tokens=n_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:n_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:n_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:n_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:n_prefills + 1],
            seq_start_loc=self.seq_start_loc[:n_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:n_prefills],
            block_tables=self.block_tables[:n_prefills],
            use_cuda_graph=False,
            block_tables_list=self.block_tables_list,
        )
        self._cached_prefill_metadata = view
        return view

    @property
    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return the metadata restricted to the decode requests.

        Returns None when the batch has no decode tokens. The view is built
        on first access and cached on this instance afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        n_prefills = self.num_prefills
        # Decode entries follow the prefill entries, so the per-sequence
        # fields are sliced from `n_prefills` onward.
        view = BlocksparseFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[n_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[n_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            block_tables_list=self.block_tables_list,
        )
        self._cached_decode_metadata = view
        return view
class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
    """`CommonMetadataBuilder` specialised for the block-sparse metadata."""

    # Metadata class this builder is bound to.
    _metadata_cls = BlocksparseFlashAttentionMetadata
class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Validate the configuration and build the block-sparse kernel.

        Args:
            num_heads: Number of query heads on this rank.
            head_size: Size of each attention head; must be one of
                `PagedAttention.get_supported_head_sizes()`.
            scale: Softmax scale, stored as a Python float.
            num_kv_heads: Number of kv heads on this rank.
            alibi_slopes: Must be None.
            sliding_window: Must be None.
            kv_cache_dtype: Passed through to the paged-attention calls.
            blocksparse_params: Keyword arguments for `BlocksparseParams`;
                required. `num_heads` / `num_kv_heads` are filled in from
                the arguments above when absent. The caller's dict is not
                modified.
            logits_soft_cap: Must be None.
            attn_type: Only `AttentionType.DECODER` is accepted.
            kv_sharing_target_layer_name: Must be None.

        Raises:
            NotImplementedError: If KV sharing is requested or `attn_type`
                is not the decoder type.
            ValueError: If `head_size` is not supported by PagedAttention.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        assert blocksparse_params is not None
        assert alibi_slopes is None, ValueError(
            "Alibi not support for blocksparse flash attention.")
        assert sliding_window is None, ValueError(
            "sliding_window is invalid for blocksparse attention.")
        assert logits_soft_cap is None, ValueError(
            "logits_soft_cap is invalid for blocksparse attention.")

        # Work on a private copy so that filling in the head counts does not
        # leak into the dict owned by the caller.
        blocksparse_params = dict(blocksparse_params)
        blocksparse_params.setdefault("num_heads", num_heads)
        blocksparse_params.setdefault("num_kv_heads", num_kv_heads
                                      or num_heads)
        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_kv_heads

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Shortcuts to the sparsity settings used on the decode path.
        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        total_num_heads = num_heads * self.tp_size
        # Kernel used for the prompt (prefill) path in `forward`.
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If `output_scale` is given.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for BlocksparseFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            assert kv_cache.numel() == 0 \
                    or prefill_meta.block_tables is None \
                    or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        assert output is not None
        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
vllm/attention/backends/cpu_mla.py
deleted
100644 → 0
View file @
0628e4b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm._custom_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadataBuilder
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
MLACommonImpl
,
MLACommonState
from
vllm.attention.backends.torch_sdpa
import
TorchSDPAMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
ModelInputForCPUBuilder
class CPUMLABackend(AttentionBackend):
    """Backend descriptor for MLA attention on CPU."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        """Return the metadata class used by this backend."""
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        """Return the metadata builder class."""
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        """Return the attention state class."""
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        """Return the attention implementation class."""
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return `(num_blocks, block_size, head_size)`.

        `num_kv_heads` does not contribute to the shape.
        """
        shape = (num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to `ops.swap_blocks`."""
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Forward a block copy request to `ops.copy_blocks_mla`."""
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of supported head sizes (only 576)."""
        return [576]
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    """`TorchSDPAMetadata` extended with the fields needed for MLA."""

    # New for MLA
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor = None

    # required by MLACommonImpl
    is_profile_run: bool = False
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
    """Builds `CPUMLAMetadata` from the data held by a CPU input builder."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        """Keep a reference to `input_builder`; rejects chunked prefill."""
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        """Snapshot `input_builder.input_data` for the next `build` call."""
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        """Assemble the `CPUMLAMetadata` for the current batch.

        `cuda_graph_pad_size` and `batch_size` are accepted but not used.
        All tensors are created on the CPU device.
        """
        data = self.input_data
        n_prefills = data.num_prefills
        # Keyword arguments shared by every int32 CPU tensor built below.
        i32_cpu = dict(dtype=torch.int32, device="cpu")

        prefill_seq_lens = seq_lens[0:n_prefills]
        prefill_query_lens = query_lens[0:n_prefills]
        slot_mapping = torch.tensor(data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None
        prefill_block_tables = None
        if n_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens, **i32_cpu)
            kv_lens_tensor = torch.tensor(prefill_seq_lens, **i32_cpu)
            # Cumulative offsets with a leading zero: the cumsum is written
            # into elements 1.. of a zero-initialised tensor.
            query_start_loc = torch.zeros(n_prefills + 1, **i32_cpu)
            kv_start_loc = torch.zeros(n_prefills + 1, **i32_cpu)
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    data.prefill_block_tables,
                    pad=0,
                    **i32_cpu,
                )

        # metadata for decode
        if data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(data.seq_lens[n_prefills:],
                                           **i32_cpu)
            block_tables = make_tensor_with_pad(
                data.decode_block_tables,
                pad=0,
                **i32_cpu,
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(data.seq_lens[:n_prefills],
                                           **i32_cpu)

        # For multi-modal models
        placeholder_index_maps = None
        if len(data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {}
            for modality, placeholder_map in \
                    data.multi_modal_placeholder_maps.items():
                placeholder_index_maps[modality] = placeholder_map.index_map()

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=data.max_decode_seq_len,
            num_prefills=n_prefills,
            num_prefill_tokens=data.num_prefill_tokens,
            num_decode_tokens=data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([data.input_positions]),
        )
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
    """MLA attention implementation for CPU.

    Prefill goes through `ipex_ops.varlen_attention`; decode goes through
    `ops.mla_decode_kvcache_cpu`.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state, then reject unsupported options.

        Raises:
            NotImplementedError: If any of `alibi_slopes`, `sliding_window`,
                `blocksparse_params`, `logits_soft_cap` is truthy, if
                `attn_type` is not the decoder type, or if the KV cache
                dtype is quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        rejected_options = (alibi_slopes, sliding_window, blocksparse_params,
                            logits_soft_cap)
        if any(rejected_options):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run causal variable-length attention over the prefill tokens.

        `kv_c_and_k_pe_cache` is not read on this path. The result is
        reshaped to `(-1, num_heads * v_head_dim)`.
        """
        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        # Up-project the compressed kv and split it into k (no-pe) and v.
        nope_dim = self.qk_nope_head_dim
        v_dim = self.v_head_dim
        projected = self.kv_b_proj(kv_c_normed)[0]
        kv_nope = projected.view(-1, self.num_heads, nope_dim + v_dim)
        k_nope, v = kv_nope.split([nope_dim, v_dim], dim=-1)

        k_pe_expanded = k_pe.expand((*k_nope.shape[:-1], -1))
        k = torch.cat((k_nope, k_pe_expanded), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        qk_dim = q.shape[-1]
        out_dim = v.shape[-1]
        v_padded = torch.nn.functional.pad(v, [0, qk_dim - out_dim], value=0)

        output = torch.empty_like(q)
        # The query offsets / max length are also used for the key side.
        start_loc = prefill_metadata.prefill_query_start_loc
        max_len = prefill_metadata.max_query_len
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=start_loc,
            seqlen_k=start_loc,
            max_seqlen_q=max_len,
            max_seqlen_k=max_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # remove padding
        output = output.view(-1, self.num_heads, qk_dim)[..., :out_dim]
        return output.reshape(-1, self.num_heads * out_dim)

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run decode attention against the cache and up-project the result.

        Requires a non-empty `kv_c_and_k_pe_cache` and decode metadata.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        # Output buffer filled in place by the decode kernel.
        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                   decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj(o)
vllm/attention/backends/ipex_attn.py
deleted
100644 → 0
View file @
0628e4b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
# Module-level logger for this backend.
logger = init_logger(__name__)

# NOTE(review): not referenced in the visible part of this file -- confirm
# whether it is still needed.
_PARTITION_SIZE = 512
class IpexAttnBackend(AttentionBackend):
    """Backend descriptor for IPEX attention."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "IPEX"

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        """Return the attention implementation class."""
        return IpexAttnBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        """Return the metadata class used by this backend."""
        return IpexAttnMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by `PagedAttention`."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Forward a block swap request to `ipex_ops.swap_blocks`."""
        from vllm._ipex_ops import ipex_ops as ops
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks in every layer's key and value cache.

        Each entry of `kv_caches` is indexed with 0 for its key cache and
        with 1 for its value cache before calling `ipex_ops.copy_blocks`.
        """
        from vllm._ipex_ops import ipex_ops as ops
        key_caches = [layer_cache[0] for layer_cache in kv_caches]
        value_caches = [layer_cache[1] for layer_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for IpexAttnBackend.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]

    def __post_init__(self):
        """Initialise `attn_bias`, which is not a dataclass field."""
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return `self` for a batch without decode tokens, else None."""
        # Currently chunked prefill is not supported
        if self.num_decode_tokens != 0:
            return None
        assert self.num_prefills > 0
        return self

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        """Return `self` for a batch without prefills, else None."""
        # Currently chunked prefill is not supported
        if self.num_prefills <= 0:
            return self
        assert self.num_decode_tokens == 0
        return None
class
IpexAttnBackendImpl
(
AttentionImpl
[
IpexAttnMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in Ipex is not supported yet, it will fall"
" back to global attention for long context."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"IPEX backend does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
sliding_window
is
not
None
)
if
logits_soft_cap
is
None
:
logits_soft_cap
=
-
1
self
.
logits_soft_cap
=
logits_soft_cap
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
is_quantized_kv_cache
(
kv_cache_dtype
):
raise
NotImplementedError
(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl"
)
def split_kv_cache(
    self,
    kv_cache: torch.Tensor,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split the stacked KV cache into reshaped key and value views.

    Args:
        kv_cache: Tensor whose index 0 along the first axis holds the keys
            and index 1 holds the values; ``kv_cache.shape[1]`` is taken
            as the number of blocks.
        num_kv_heads: Number of KV heads used for the reshape.
        head_size: Per-head size used for the reshape.

    Returns:
        ``(key_cache, value_cache)`` where the key view has shape
        ``[num_blocks, num_kv_heads, head_size // x, -1, x]`` with ``x = 1``
        and the value view has shape
        ``[num_blocks, num_kv_heads, head_size, -1]``.
    """
    # Width of the trailing axis of the key view; fixed to 1 here.
    pack = 1
    block_count = kv_cache.shape[1]
    raw_keys, raw_values = kv_cache[0], kv_cache[1]

    keys = raw_keys.view(block_count, num_kv_heads, head_size // pack, -1,
                         pack)
    values = raw_values.view(block_count, num_kv_heads, head_size, -1)
    return keys, values
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: IpexAttnMetadata,  # type: ignore
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with IPEX varlen_attention and PagedAttention.

    When ``kv_cache`` is non-empty the new keys/values are first written
    into it. Prompt batches then run ``ipex_ops.varlen_attention``; all
    other batches run ``ipex_ops.paged_attention_v1`` or ``_v2``.

    Args:
        layer: Attention layer; only its ``_k_scale_float`` and
            ``_v_scale_float`` attributes are read, and both must be 1.0.
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            NOTE: kv_cache will be an empty tensor with shape [0]
            for profiling run.
        attn_metadata: Metadata for attention.
        output: Ignored on input; the name is rebound to a freshly
            allocated tensor on every path that returns.
        output_scale: Must be None.

    Returns:
        shape = [num_tokens, num_heads * head_size]

    Raises:
        NotImplementedError: If ``output_scale`` is not None.
        RuntimeError: For a prompt batch where both ``kv_cache`` and
            ``attn_metadata.block_tables`` are non-empty.
    """
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for IpexAttentionImpl")
    assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
    # NOTE(review): hidden_size is unpacked but never used below.
    num_tokens, hidden_size = query.shape
    # Reshape the query, key, and value tensors.
    query = query.view(-1, self.num_heads, self.head_size)
    key = key.view(-1, self.num_kv_heads, self.head_size)
    value = value.view(-1, self.num_kv_heads, self.head_size)

    # key_cache / value_cache are only bound when the cache is non-empty;
    # the decoding branch below relies on them being bound.
    if kv_cache.numel() > 0:
        key_cache, value_cache = self.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)
        ipex_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping.flatten(),
            self.kv_cache_dtype,
            layer._k_scale_float,
            layer._v_scale_float,
        )

    if attn_metadata.is_prompt:
        assert attn_metadata.seq_lens is not None
        if (kv_cache.numel() == 0
                or attn_metadata.block_tables.numel() == 0):
            # Expand KV heads so key/value have as many heads as query.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=1)

            # Build the bias once and store it on the metadata object.
            # NOTE(review): the stored bias is not passed to the
            # varlen_attention call below — confirm whether another
            # consumer reads attn_metadata.attn_bias.
            if attn_metadata.attn_bias is None:
                if self.sliding_window is not None:
                    att_masks = _make_sliding_window_bias(
                        attn_metadata.seq_lens, self.sliding_window,
                        query.dtype)  # type: ignore
                else:
                    att_masks = _make_sliding_window_bias(
                        attn_metadata.seq_lens, None, dtype=query.dtype)
                attn_metadata.attn_bias = att_masks

            output = torch.empty(
                (num_tokens, self.num_heads, self.head_size),
                dtype=query.dtype,
                device=query.device)
            # seqlen_q and max_seqlen are each passed twice, in
            # consecutive positional slots.
            ipex_ops.varlen_attention(
                query,
                key,
                value,
                output,
                attn_metadata.seqlen_q,
                attn_metadata.seqlen_q,
                self.alibi_slopes,
                attn_metadata.max_seqlen,
                attn_metadata.max_seqlen,
                pdropout=0.0,
                softmax_scale=self.scale,
                zero_tensors=False,
                is_causal=True,
                return_softmax=False,
                gen_=None,
                window_size_left=-1,
                window_size_right=-1,
                logits_soft_cap=self.logits_soft_cap,
            )
        else:
            # prefix-enabled attention
            raise RuntimeError(
                "IPEX backend doesn't support prefix decoding.")

    else:
        # Decoding run.
        max_seq_len = attn_metadata.max_decode_seq_len
        output = torch.empty_like(query)
        # Axis 3 of the value view produced by split_kv_cache.
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        # Ceiling division of max_seq_len by _PARTITION_SIZE.
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory
        # shortage.
        use_v1 = (max_seq_len <= 8192 and
                  (max_num_partitions == 1 or num_seqs * num_heads > 512))
        if use_v1:
            # Run PagedAttention V1.
            ipex_ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                self.num_kv_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens_tensor,
                block_size,
                max_seq_len,
                self.alibi_slopes,
                self.kv_cache_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Scratch buffers with one slot per partition, handed to the
            # V2 kernel alongside the final output tensor.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ipex_ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                self.num_kv_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens_tensor,
                block_size,
                max_seq_len,
                self.alibi_slopes,
                self.kv_cache_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )

    # Reshape the output tensor.
    return output.view(-1, self.num_heads * self.head_size)
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
,
device
=
alibi_slopes
.
device
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
,
device
=
alibi_slopes
.
device
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
seq_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
vllm/attention/backends/pallas.py
deleted
100644 → 0
View file @
0628e4b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor that wires up the Pallas attention classes."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the cache shape, with the KV-head axis first."""
        cache_shape = (num_kv_heads, num_blocks, block_size, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is unused for this backend."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy blocks in place inside every key and value cache.

        ``src_to_dists`` is a ``(source indices, destination indices)``
        pair; both index axis 1 of each cache tensor.
        """
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Key cache first, then value cache; each is marked as a
            # buffer donor right before its in-place update.
            for cache in (k_cache, v_cache):
                torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
                cache[:, dst_indices] = cache[:, src_indices]
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    A batch is either all prefills or all decodes; the two properties
    below return ``self`` for the matching kind and ``None`` otherwise.
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, else ``None``."""
        if self.num_prefills != 0:
            # A prefill batch must not carry any decode tokens.
            assert self.num_decode_tokens == 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, else ``None``."""
        if self.num_decode_tokens != 0:
            # A decode batch must not carry prefills and needs both the
            # block tables and the context lengths.
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self
        return None
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to XLA/Pallas kernels.

    Only decoder attention is accepted. ALiBi slopes, sliding window,
    quantized KV cache, block-sparse parameters and KV sharing are all
    rejected in ``__init__``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and pick the megacore mode.

        Raises:
            NotImplementedError: For KV sharing, a head size that is not
                a multiple of 128, ALiBi slopes, sliding window, a
                quantized KV cache dtype, block-sparse parameters, a TPU
                version below 4, or a non-decoder ``attn_type``.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            # Not an error: only a one-time warning is logged.
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError(
                f"Head size must be a multiple of 128, found {head_size}.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type is looked up under three possible keys;
        # the first truthy value wins.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # megacore_mode stays None when the type string contains "lite"
        # or "v6"; otherwise it depends on the parity of num_kv_heads.
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            layer: Attention layer; only its ``_k_scale_float`` and
                ``_v_scale_float`` attributes are read, and both must
                be 1.0.
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
            output: Ignored on input; the name is rebound to the kernel
                result.
            output_scale: Must be None.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: If ``output_scale`` is not None.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # key_cache / value_cache are only bound when the cache is
        # non-empty; the paged branches below rely on them being bound.
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # The softmax scale is applied to the query up front; none of the
        # kernel calls below receives a scale argument.
        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                # Swap the head and sequence axes back.
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # NOTE(review): the factor 4 is presumably the byte size of
            # one block-table entry — confirm the table dtype.
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                # Ceiling division of batch_size by chunk_size.
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Scatter new keys and values into the caches at ``slot_mapping``.

    The first three axes of every tensor are flattened into one, then
    rows of ``key`` / ``value`` are copied into the flattened caches at
    the row indices given by ``slot_mapping``.
    """
    # Mark both caches as buffer donors before the in-place updates.
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    flat_key = key.flatten(0, 2)
    flat_value = value.flatten(0, 2)
    flat_key_cache = key_cache.flatten(0, 2)
    flat_value_cache = value_cache.flatten(0, 2)

    flat_key_cache.index_copy_(0, slot_mapping, flat_key)
    flat_value_cache.index_copy_(0, slot_mapping, flat_value)
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Call the XLA paged-attention op with a batch-aware megacore mode.

    ``"batch"`` megacore mode is replaced by ``None`` when the batch size
    (``query.shape[0]``) is odd; every other mode is forwarded unchanged.
    """
    batch_size = query.shape[0]
    odd_batch_in_batch_mode = (megacore_mode == "batch"
                               and batch_size % 2 != 0)
    effective_mode = None if odd_batch_in_batch_mode else megacore_mode

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=effective_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
vllm/attention/backends/torch_sdpa.py
deleted
100644 → 0
View file @
0628e4b4
This diff is collapsed.
Click to expand it.
vllm/attention/backends/tree_decoding_utils.py
View file @
1b78ef29
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
,
Optional
import
torch
from
vllm.attention.backends.blocksparse_attn
import
BlocksparseFlashAttentionImpl
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
...
@@ -52,4 +51,4 @@ def move_cache(
kv_cache_dtype
)
else
:
raise
NotImplementedError
(
"Only BlocksparseFlashAttention/ROCmFlash/XFormers backends support move cache for now!"
)
\ No newline at end of file
raise
NotImplementedError
(
"Only ROCmFlash/XFormers backends support move cache for now!"
)
\ No newline at end of file
vllm/zero_overhead/llm_engine.py
View file @
1b78ef29
...
...
@@ -25,7 +25,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
ParallelSampleSequenceGroup
,
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
from
vllm.tracing
import
init_tracer
...
...
@@ -588,7 +587,6 @@ class ZeroOverheadEngine(LLMEngine):
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
Optional
[
SequenceGroup
]:
...
...
@@ -604,7 +602,6 @@ class ZeroOverheadEngine(LLMEngine):
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
return
None
...
...
@@ -618,11 +615,10 @@ class ZeroOverheadEngine(LLMEngine):
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
seq
=
ZeroOverheadSequence
(
seq_id
,
decoder_inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
)
lora_request
)
encoder_seq
=
(
None
if
encoder_inputs
is
None
else
ZeroOverheadSequence
(
seq_id
,
encoder_inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
))
seq_id
,
encoder_inputs
,
block_size
,
eos_token_id
,
lora_request
))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if
isinstance
(
params
,
SamplingParams
):
...
...
@@ -633,7 +629,6 @@ class ZeroOverheadEngine(LLMEngine):
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
,
priority
=
priority
)
elif
isinstance
(
params
,
PoolingParams
):
...
...
@@ -643,7 +638,6 @@ class ZeroOverheadEngine(LLMEngine):
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
,
priority
=
priority
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment