Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
253 additions
and
41 deletions
+253
-41
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+6
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+7
-4
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+6
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-1
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+6
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+7
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+45
-6
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+2
-2
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+6
-1
vllm/attention/layer.py
vllm/attention/layer.py
+21
-8
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+3
-3
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+6
-9
vllm/attention/ops/pallas_kv_cache_update.py
vllm/attention/ops/pallas_kv_cache_update.py
+120
-0
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+6
-1
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
vllm/attention/backends/flash_attn.py
View file @
99324e25
...
...
@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
...
...
@@ -673,6 +672,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
...
...
@@ -692,6 +692,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if
not
flash_attn_supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
assert
(
...
...
vllm/attention/backends/flashinfer.py
View file @
99324e25
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
os
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
...
...
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
envs
.
VLLM_KV_CACHE_LAYOUT
or
"NHD"
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -957,7 +955,6 @@ class FlashInferImpl(AttentionImpl):
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
attn_type
!=
AttentionType
.
DECODER
:
...
...
@@ -975,8 +972,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashInferMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for FlashInferImpl"
)
# TODO: directly write to output tensor
num_heads
:
int
=
self
.
num_heads
head_size
:
int
=
self
.
head_size
...
...
vllm/attention/backends/hpu_attn.py
View file @
99324e25
...
...
@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
alibi_slopes_tensor
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
bfloat16
)
self
.
alibi_slopes
=
alibi_slopes_tensor
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
self
.
prefill_impl
==
'fsdpa'
:
...
...
@@ -181,6 +180,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
HPUAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
@@ -193,6 +193,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for HPUAttentionImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
...
...
vllm/attention/backends/ipex_attn.py
View file @
99324e25
...
...
@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
sliding_window
is
not
None
)
if
logits_soft_cap
is
None
:
...
...
@@ -192,6 +191,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
...
@@ -206,6 +206,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for IpexAttentionImpl"
)
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
...
...
vllm/attention/backends/mla/common.py
View file @
99324e25
...
...
@@ -1334,11 +1334,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
output
is
not
None
:
raise
NotImplementedError
(
"output is not yet supported for MLAImplBase"
)
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for MLAImplBase"
)
if
attn_metadata
.
is_profile_run
and
\
attn_metadata
.
context_chunk_workspace
is
not
None
:
# During the profile run try to simulate to worse case output size
...
...
vllm/attention/backends/pallas.py
View file @
99324e25
...
...
@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
logits_soft_cap
=
logits_soft_cap
if
head_size
%
128
!=
0
:
...
...
@@ -172,6 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
attn_metadata
:
PallasMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
...
...
@@ -187,6 +187,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for PallasAttentionImpl"
)
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
99324e25
...
...
@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
...
...
@@ -37,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@
cache
def
_get_paged_attn_module
()
->
PagedAttention
:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
...
...
@@ -527,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
paged_attn_module
=
_get_paged_attn_module
()
...
...
@@ -584,6 +584,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
self
.
aiter_kv_scales_initialized
=
False
self
.
force_fp8_attention
=
(
get_current_vllm_config
()
is
not
None
and
get_current_vllm_config
().
model_config
.
override_attention_dtype
==
"fp8"
)
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
...
...
@@ -593,6 +597,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim
).
reshape
(
tokens
,
n_kv_heads
*
n_rep
,
head_dim
))
def
fused_output_quant_supported
(
self
,
dtype
:
torch
.
dtype
,
static
:
bool
,
group_shape
:
tuple
[
int
,
int
]):
if
self
.
use_triton_flash_attn
:
return
dtype
==
current_platform
.
fp8_dtype
(
)
and
static
and
group_shape
==
(
-
1
,
-
1
)
# per-tensor
# Only supported in the Triton backend
return
False
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
@@ -602,6 +615,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -655,6 +669,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
and
not
self
.
use_triton_flash_attn
:
raise
NotImplementedError
(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now"
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
...
...
@@ -770,9 +789,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
use_fp8_scales
=
(
layer
.
_q_scale
and
layer
.
_k_scale
and
layer
.
_v_scale
and
layer
.
_prob_scale
and
self
.
kv_cache_dtype
==
"fp8"
)
and
(
self
.
kv_cache_dtype
==
"fp8"
or
self
.
force_fp8_attention
))
full_scales
=
(
layer
.
_q_scale
.
item
(),
layer
.
_k_scale
.
item
(),
layer
.
_v_scale
.
item
(),
...
...
@@ -791,6 +813,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
[
0
][
None
]
if
attn_masks
is
not
None
else
None
,
full_scales
,
output_scale
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
...
@@ -880,7 +903,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
_PARTITION_SIZE_ROCM
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
dtype
=
query
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
...
...
@@ -914,9 +937,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
output_scale
,
)
else
:
output
[
num_prefill_tokens
:]
=
paged_attn
.
forward_decode
(
# PagedAttention does not support fused quant, manually quantize
if
output_scale
is
None
:
out_pa
=
output
[
num_prefill_tokens
:]
else
:
out_pa
=
torch
.
empty_like
(
output
[
num_prefill_tokens
:],
dtype
=
query
.
dtype
)
out_pa
[:]
=
paged_attn
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
...
...
@@ -937,6 +968,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
)
# Manually perform quantization
if
output_scale
is
not
None
:
out_uq
=
out_pa
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
out_q
=
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
ops
.
scaled_fp8_quant
(
out_uq
,
output_scale
,
output
=
out_q
[
num_prefill_tokens
:])
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
99324e25
...
...
@@ -65,7 +65,7 @@ class TorchSDPABackend(AttentionBackend):
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
raise
NotImplementedError
(
"Swap is not supported in TorchSDPABackend."
)
@
staticmethod
def
copy_blocks
(
...
...
@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
...
...
@@ -459,6 +458,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -473,6 +473,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl"
)
# For warming-up
if
attn_metadata
is
None
:
...
...
vllm/attention/backends/utils.py
View file @
99324e25
...
...
@@ -373,7 +373,7 @@ class CommonAttentionState(AttentionState):
f
"Expected attn_backend name to be either 'XFORMERS',"
\
f
"'ROCM_FLASH', or 'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
self
.
_add_addit
i
onal_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
return
input_buffers
...
...
@@ -427,7 +427,7 @@ class CommonAttentionState(AttentionState):
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
attn_metadata
.
num_encoder_tokens
=
0
def
_add_additonal_input_buffers_for_enc_dec_model
(
def
_add_addit
i
onal_input_buffers_for_enc_dec_model
(
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
"""
Saves additional input buffers specific to the encoder-decoder model
...
...
vllm/attention/backends/xformers.py
View file @
99324e25
...
...
@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
...
...
@@ -435,6 +434,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
"XFormersMetadata"
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
@@ -487,6 +487,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for XFormersImpl"
)
attn_type
=
self
.
attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
...
...
vllm/attention/layer.py
View file @
99324e25
...
...
@@ -80,6 +80,9 @@ class Attention(nn.Module):
calculate_kv_scales
=
False
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
assert
num_heads
%
num_kv_heads
==
0
,
\
f
"num_heads (
{
num_heads
}
) is not "
\
f
"divisible by num_kv_heads (
{
num_kv_heads
}
)"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
...
...
@@ -206,7 +209,7 @@ class Attention(nn.Module):
if
self
.
use_output
:
output_shape
=
(
output_shape
if
output_shape
is
not
None
else
query
.
shape
)
output
=
torch
.
empty
(
output_shape
,
output
=
torch
.
zeros
(
output_shape
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
hidden_size
=
output_shape
[
-
1
]
...
...
@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
,
\
f
"num_heads (
{
self
.
num_heads
}
) is not "
\
f
"divisible by num_kv_heads (
{
self
.
num_kv_heads
}
)"
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
dtype
=
torch
.
get_default_dtype
()
...
...
@@ -301,12 +306,17 @@ class MultiHeadAttention(nn.Module):
block_size
=
16
,
is_attention_free
=
False
)
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
if
backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
}:
backend
=
_Backend
.
XFORMERS
if
current_platform
.
is_rocm
():
# currently, only torch_sdpa is supported on rocm
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
else
:
if
backend
in
(
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
FLEX_ATTENTION
):
backend
=
_Backend
.
XFORMERS
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
}
else
_Backend
.
TORCH_SDPA
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
}
else
_Backend
.
TORCH_SDPA
def
forward
(
self
,
...
...
@@ -430,6 +440,7 @@ def unified_attention_with_output(
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
wait_for_kv_layer_from_connector
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
...
...
@@ -444,7 +455,8 @@ def unified_attention_with_output(
value
,
kv_cache
,
attn_metadata
,
output
=
output
)
output
=
output
,
output_scale
=
output_scale
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
...
...
@@ -455,6 +467,7 @@ def unified_attention_with_output_fake(
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
return
...
...
vllm/attention/ops/ipex_attn.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
...
...
@@ -29,7 +29,7 @@ class _PagedAttention:
head_size
:
int
,
*
args
,
)
->
Tuple
[
int
,
...]:
return
(
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
)
return
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
@
staticmethod
def
split_kv_cache
(
...
...
@@ -120,7 +120,7 @@ class _PagedAttention:
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
*
args
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
...
...
vllm/attention/ops/nki_flash_attn.py
View file @
99324e25
...
...
@@ -8,9 +8,7 @@ import torch
from
neuronxcc
import
nki
from
neuronxcc.nki.language
import
par_dim
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
from
vllm.utils
import
cdiv
def
is_power_of_2
(
x
):
...
...
@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
(
num_tiles
,
num_blocks_per_tile
))
block_tables_sbuf
=
nl
.
zeros
(
(
ceil_div
(
num_tiles
,
B_P_SIZE
),
par_dim
(
B_P_SIZE
),
num_blocks_per_tile
),
(
cdiv
(
num_tiles
,
B_P_SIZE
),
par_dim
(
B_P_SIZE
),
num_blocks_per_tile
),
dtype
=
nl
.
int32
,
)
for
i
in
nl
.
affine_range
(
c
eil_
div
(
num_tiles
,
B_P_SIZE
)):
for
i
in
nl
.
affine_range
(
cdiv
(
num_tiles
,
B_P_SIZE
)):
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_f
=
nl
.
arange
(
num_blocks_per_tile
)[
None
,
:]
block_tables_sbuf
[
i
,
i_p
,
i_f
]
=
nl
.
load
(
...
...
@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
assert
is_power_of_2
(
num_blocks_per_tile
),
f
"
{
num_blocks_per_tile
=
}
is not power of 2"
num_loads
=
c
eil_
div
(
num_blocks_per_tile
,
B_P_SIZE
)
num_loads
=
cdiv
(
num_blocks_per_tile
,
B_P_SIZE
)
block_tables_transposed
=
nl
.
ndarray
(
(
num_loads
,
...
...
@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
# load key cache
num_loads
=
c
eil_
div
(
num_blocks_per_large_tile
,
B_P_SIZE
)
num_loads
=
cdiv
(
num_blocks_per_large_tile
,
B_P_SIZE
)
for
load_idx
in
nl
.
affine_range
(
num_loads
):
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_f
=
nl
.
arange
(
tiled_block_size
*
B_D_SIZE
)[
None
,
:]
...
...
@@ -605,7 +602,7 @@ def flash_paged_attention(
)
for
large_k_tile_idx
in
nl
.
sequential_range
(
0
,
num_large_k_tile
):
num_loads
=
c
eil_
div
(
num_blocks_per_large_tile
,
B_P_SIZE
)
num_loads
=
cdiv
(
num_blocks_per_large_tile
,
B_P_SIZE
)
cur_k_tile
=
nl
.
ndarray
(
(
par_dim
(
B_D_SIZE
),
LARGE_TILE_SZ
),
dtype
=
kernel_dtype
,
...
...
vllm/attention/ops/pallas_kv_cache_update.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
jax
from
jax.experimental
import
pallas
as
pl
from
jax.experimental.pallas
import
tpu
as
pltpu
from
vllm.utils
import
cdiv
def
_kv_cache_update_kernel
(
# Prefetch
slices_ref
,
# [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
# Input
new_kv_hbm_ref
,
# [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref
,
# [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_
,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch
,
# [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem
,
):
async_copies
=
[]
block_idx
=
pl
.
program_id
(
0
)
num_slices_per_block
=
scratch
.
shape
[
0
]
# Copy from new_kv_hbm_ref to scratch
for
i
in
range
(
num_slices_per_block
):
offset_i
=
i
+
block_idx
*
num_slices_per_block
new_kv_start
=
slices_ref
[
1
,
offset_i
]
length
=
slices_ref
[
2
,
offset_i
]
async_copy
=
pltpu
.
make_async_copy
(
new_kv_hbm_ref
.
at
[
pl
.
ds
(
new_kv_start
,
length
),
...],
scratch
.
at
[
i
,
pl
.
ds
(
0
,
length
),
...],
sem
,
)
async_copy
.
start
()
async_copies
.
append
(
async_copy
)
for
async_copy
in
async_copies
:
async_copy
.
wait
()
# Copy from scratch to kv_cache_hbm_ref
async_copies
.
clear
()
for
i
in
range
(
num_slices_per_block
):
offset_i
=
i
+
block_idx
*
num_slices_per_block
kv_cache_start
=
slices_ref
[
0
,
offset_i
]
length
=
slices_ref
[
2
,
offset_i
]
async_copy
=
pltpu
.
make_async_copy
(
scratch
.
at
[
i
,
pl
.
ds
(
0
,
length
),
...],
kv_cache_hbm_ref
.
at
[
pl
.
ds
(
kv_cache_start
,
length
),
...],
sem
,
)
async_copy
.
start
()
async_copies
.
append
(
async_copy
)
for
async_copy
in
async_copies
:
async_copy
.
wait
()
@
functools
.
partial
(
jax
.
jit
,
static_argnames
=
[
"page_size"
,
"num_slices_per_block"
],
)
def
kv_cache_update
(
new_kv
:
jax
.
Array
,
# [total_num_token, num_combined_kv_heads, head_dim]
slices
:
jax
.
Array
,
# [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache
:
jax
.
Array
,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices
:
jax
.
Array
,
# [1]
*
,
page_size
:
int
=
32
,
num_slices_per_block
:
int
=
8
,
):
assert
slices
.
shape
[
1
]
%
num_slices_per_block
==
0
_
,
num_combined_kv_heads
,
head_dim
=
new_kv
.
shape
assert
kv_cache
.
shape
[
1
]
==
num_combined_kv_heads
assert
kv_cache
.
shape
[
2
]
==
head_dim
assert
head_dim
%
128
==
0
# TODO: Add dynamic check to make sure that the all the slice lengths are
# smaller or equal to page_size
in_specs
=
[
pl
.
BlockSpec
(
memory_space
=
pltpu
.
TPUMemorySpace
.
ANY
),
pl
.
BlockSpec
(
memory_space
=
pltpu
.
TPUMemorySpace
.
ANY
),
]
out_specs
=
[
pl
.
BlockSpec
(
memory_space
=
pltpu
.
TPUMemorySpace
.
ANY
)]
out_shape
=
[
jax
.
ShapeDtypeStruct
(
kv_cache
.
shape
,
dtype
=
kv_cache
.
dtype
)]
scalar_prefetches
=
[
slices
]
scratch
=
pltpu
.
VMEM
(
(
num_slices_per_block
,
page_size
,
num_combined_kv_heads
,
head_dim
),
new_kv
.
dtype
,
)
scratch_shapes
=
[
scratch
,
pltpu
.
SemaphoreType
.
DMA
,
]
kernel
=
pl
.
pallas_call
(
_kv_cache_update_kernel
,
grid_spec
=
pltpu
.
PrefetchScalarGridSpec
(
num_scalar_prefetch
=
len
(
scalar_prefetches
),
in_specs
=
in_specs
,
out_specs
=
out_specs
,
grid
=
(
cdiv
(
num_kv_update_slices
[
0
],
num_slices_per_block
),
),
scratch_shapes
=
scratch_shapes
,
),
out_shape
=
out_shape
,
input_output_aliases
=
{
len
(
scalar_prefetches
)
+
1
:
0
},
)
return
kernel
(
*
scalar_prefetches
,
new_kv
,
kv_cache
)[
0
]
vllm/attention/ops/triton_flash_attention.py
View file @
99324e25
...
...
@@ -25,9 +25,14 @@ Not currently supported:
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
on_gfx1x
from
vllm.triton_utils
import
tl
,
triton
# Avoid misleading ROCm warning.
if
current_platform
.
is_rocm
():
from
vllm.platforms.rocm
import
on_gfx1x
else
:
on_gfx1x
=
lambda
*
args
,
**
kwargs
:
False
torch_dtype
:
tl
.
constexpr
=
torch
.
float16
...
...
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment