Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
253 additions
and
41 deletions
+253
-41
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+6
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+7
-4
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+6
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-1
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+6
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+7
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+45
-6
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+2
-2
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+6
-1
vllm/attention/layer.py
vllm/attention/layer.py
+21
-8
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+3
-3
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+6
-9
vllm/attention/ops/pallas_kv_cache_update.py
vllm/attention/ops/pallas_kv_cache_update.py
+120
-0
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+6
-1
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
vllm/attention/backends/flash_attn.py
View file @
99324e25
...
@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
0
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
...
@@ -673,6 +672,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -673,6 +672,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
"""Forward pass with FlashAttention.
...
@@ -692,6 +692,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -692,6 +692,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if
not
flash_attn_supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
if
not
flash_attn_supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
assert
(
assert
(
...
...
vllm/attention/backends/flashinfer.py
View file @
99324e25
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
dataclasses
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
...
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
envs
.
VLLM_KV_CACHE_LAYOUT
or
"NHD"
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -957,7 +955,6 @@ class FlashInferImpl(AttentionImpl):
...
@@ -957,7 +955,6 @@ class FlashInferImpl(AttentionImpl):
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
...
@@ -975,8 +972,14 @@ class FlashInferImpl(AttentionImpl):
...
@@ -975,8 +972,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashInferMetadata
,
attn_metadata
:
FlashInferMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for FlashInferImpl"
)
# TODO: directly write to output tensor
# TODO: directly write to output tensor
num_heads
:
int
=
self
.
num_heads
num_heads
:
int
=
self
.
num_heads
head_size
:
int
=
self
.
head_size
head_size
:
int
=
self
.
head_size
...
...
vllm/attention/backends/hpu_attn.py
View file @
99324e25
...
@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
alibi_slopes_tensor
=
torch
.
tensor
(
alibi_slopes
,
alibi_slopes_tensor
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
bfloat16
)
dtype
=
torch
.
bfloat16
)
self
.
alibi_slopes
=
alibi_slopes_tensor
self
.
alibi_slopes
=
alibi_slopes_tensor
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
self
.
prefill_impl
==
'fsdpa'
:
if
self
.
prefill_impl
==
'fsdpa'
:
...
@@ -181,6 +180,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -181,6 +180,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
HPUAttentionMetadata
,
attn_metadata
:
HPUAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -193,6 +193,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -193,6 +193,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for HPUAttentionImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
...
...
vllm/attention/backends/ipex_attn.py
View file @
99324e25
...
@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
sliding_window
is
not
None
)
self
.
need_mask
=
(
self
.
sliding_window
is
not
None
)
if
logits_soft_cap
is
None
:
if
logits_soft_cap
is
None
:
...
@@ -192,6 +191,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -192,6 +191,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
@@ -206,6 +206,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -206,6 +206,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for IpexAttentionImpl"
)
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
...
...
vllm/attention/backends/mla/common.py
View file @
99324e25
...
@@ -1334,11 +1334,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1334,11 +1334,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
T
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
output
is
not
None
:
if
output
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"output is not yet supported for MLAImplBase"
)
"output is not yet supported for MLAImplBase"
)
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for MLAImplBase"
)
if
attn_metadata
.
is_profile_run
and
\
if
attn_metadata
.
is_profile_run
and
\
attn_metadata
.
context_chunk_workspace
is
not
None
:
attn_metadata
.
context_chunk_workspace
is
not
None
:
# During the profile run try to simulate to worse case output size
# During the profile run try to simulate to worse case output size
...
...
vllm/attention/backends/pallas.py
View file @
99324e25
...
@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
if
head_size
%
128
!=
0
:
if
head_size
%
128
!=
0
:
...
@@ -172,6 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -172,6 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
attn_metadata
:
PallasMetadata
,
attn_metadata
:
PallasMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
"""Forward pass with Pallas attention.
...
@@ -187,6 +187,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -187,6 +187,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [batch_size, seq_len, num_heads * head_size]
"""
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for PallasAttentionImpl"
)
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
99324e25
...
@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
...
@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder
)
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
...
@@ -37,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
...
@@ -37,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@
cache
@
cache
def
_get_paged_attn_module
()
->
PagedAttention
:
def
_get_paged_attn_module
()
->
PagedAttention
:
"""
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
The choice of attention module depends on whether
AITER paged attention is enabled:
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
...
@@ -527,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -527,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
paged_attn_module
=
_get_paged_attn_module
()
self
.
paged_attn_module
=
_get_paged_attn_module
()
...
@@ -584,6 +584,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -584,6 +584,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
self
.
aiter_kv_scales_initialized
=
False
self
.
aiter_kv_scales_initialized
=
False
self
.
force_fp8_attention
=
(
get_current_vllm_config
()
is
not
None
and
get_current_vllm_config
().
model_config
.
override_attention_dtype
==
"fp8"
)
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
...
@@ -593,6 +597,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -593,6 +597,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim
).
reshape
(
tokens
,
n_kv_heads
*
n_rep
,
head_dim
).
reshape
(
tokens
,
n_kv_heads
*
n_rep
,
head_dim
))
head_dim
))
def
fused_output_quant_supported
(
self
,
dtype
:
torch
.
dtype
,
static
:
bool
,
group_shape
:
tuple
[
int
,
int
]):
if
self
.
use_triton_flash_attn
:
return
dtype
==
current_platform
.
fp8_dtype
(
)
and
static
and
group_shape
==
(
-
1
,
-
1
)
# per-tensor
# Only supported in the Triton backend
return
False
def
forward
(
def
forward
(
self
,
self
,
layer
:
AttentionLayer
,
layer
:
AttentionLayer
,
...
@@ -602,6 +615,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -602,6 +615,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -655,6 +669,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -655,6 +669,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
and
not
self
.
use_triton_flash_attn
:
raise
NotImplementedError
(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now"
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
if
key
is
not
None
:
assert
value
is
not
None
assert
value
is
not
None
...
@@ -770,9 +789,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -770,9 +789,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
query
.
dtype
,
seq_lens
,
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
make_attn_mask
=
causal_mask
)
# type: ignore
use_fp8_scales
=
(
layer
.
_q_scale
and
layer
.
_k_scale
use_fp8_scales
=
(
layer
.
_q_scale
and
layer
.
_k_scale
and
layer
.
_v_scale
and
layer
.
_prob_scale
and
layer
.
_v_scale
and
layer
.
_prob_scale
and
self
.
kv_cache_dtype
==
"fp8"
)
and
(
self
.
kv_cache_dtype
==
"fp8"
or
self
.
force_fp8_attention
))
full_scales
=
(
full_scales
=
(
layer
.
_q_scale
.
item
(),
layer
.
_k_scale
.
item
(),
layer
.
_q_scale
.
item
(),
layer
.
_k_scale
.
item
(),
layer
.
_v_scale
.
item
(),
layer
.
_v_scale
.
item
(),
...
@@ -791,6 +813,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -791,6 +813,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
[
0
][
None
]
attn_masks
[
0
][
None
]
if
attn_masks
is
not
None
else
None
,
if
attn_masks
is
not
None
else
None
,
full_scales
,
full_scales
,
output_scale
,
)
)
elif
self
.
use_naive_attn
:
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
@@ -880,7 +903,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -880,7 +903,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
_PARTITION_SIZE_ROCM
%
block_size
==
0
assert
_PARTITION_SIZE_ROCM
%
block_size
==
0
tmp_output
=
torch
.
empty
(
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
dtype
=
query
.
dtype
,
device
=
output
.
device
,
device
=
output
.
device
,
)
)
exp_sums
=
torch
.
empty
(
exp_sums
=
torch
.
empty
(
...
@@ -914,9 +937,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -914,9 +937,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
,
output_scale
,
)
)
else
:
else
:
output
[
num_prefill_tokens
:]
=
paged_attn
.
forward_decode
(
# PagedAttention does not support fused quant, manually quantize
if
output_scale
is
None
:
out_pa
=
output
[
num_prefill_tokens
:]
else
:
out_pa
=
torch
.
empty_like
(
output
[
num_prefill_tokens
:],
dtype
=
query
.
dtype
)
out_pa
[:]
=
paged_attn
.
forward_decode
(
decode_query
,
decode_query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
...
@@ -937,6 +968,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -937,6 +968,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
# Manually perform quantization
if
output_scale
is
not
None
:
out_uq
=
out_pa
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
out_q
=
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
ops
.
scaled_fp8_quant
(
out_uq
,
output_scale
,
output
=
out_q
[
num_prefill_tokens
:])
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
99324e25
...
@@ -65,7 +65,7 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -65,7 +65,7 @@ class TorchSDPABackend(AttentionBackend):
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
raise
NotImplementedError
(
"Swap is not supported in TorchSDPABackend."
)
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
...
@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
or
self
.
sliding_window
is
not
None
)
...
@@ -459,6 +458,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -459,6 +458,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass with torch SDPA and PagedAttention.
...
@@ -473,6 +473,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -473,6 +473,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl"
)
# For warming-up
# For warming-up
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
...
...
vllm/attention/backends/utils.py
View file @
99324e25
...
@@ -373,7 +373,7 @@ class CommonAttentionState(AttentionState):
...
@@ -373,7 +373,7 @@ class CommonAttentionState(AttentionState):
f
"Expected attn_backend name to be either 'XFORMERS',"
\
f
"Expected attn_backend name to be either 'XFORMERS',"
\
f
"'ROCM_FLASH', or 'FLASH_ATTN', but "
\
f
"'ROCM_FLASH', or 'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
self
.
_add_addit
i
onal_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
return
input_buffers
return
input_buffers
...
@@ -427,7 +427,7 @@ class CommonAttentionState(AttentionState):
...
@@ -427,7 +427,7 @@ class CommonAttentionState(AttentionState):
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
attn_metadata
.
num_encoder_tokens
=
0
attn_metadata
.
num_encoder_tokens
=
0
def
_add_additonal_input_buffers_for_enc_dec_model
(
def
_add_addit
i
onal_input_buffers_for_enc_dec_model
(
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
"""
"""
Saves additional input buffers specific to the encoder-decoder model
Saves additional input buffers specific to the encoder-decoder model
...
...
vllm/attention/backends/xformers.py
View file @
99324e25
...
@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
...
@@ -435,6 +434,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -435,6 +434,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
"XFormersMetadata"
,
attn_metadata
:
"XFormersMetadata"
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -487,6 +487,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -487,6 +487,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for XFormersImpl"
)
attn_type
=
self
.
attn_type
attn_type
=
self
.
attn_type
# Check that appropriate attention metadata attributes are
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
# selected for the desired attention type
...
...
vllm/attention/layer.py
View file @
99324e25
...
@@ -80,6 +80,9 @@ class Attention(nn.Module):
...
@@ -80,6 +80,9 @@ class Attention(nn.Module):
calculate_kv_scales
=
False
calculate_kv_scales
=
False
if
num_kv_heads
is
None
:
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
num_kv_heads
=
num_heads
assert
num_heads
%
num_kv_heads
==
0
,
\
f
"num_heads (
{
num_heads
}
) is not "
\
f
"divisible by num_kv_heads (
{
num_kv_heads
}
)"
# The default k/v_scale is set to 1.0. This is ignored
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# when kv-cache is not fp8, and should be used with
...
@@ -206,7 +209,7 @@ class Attention(nn.Module):
...
@@ -206,7 +209,7 @@ class Attention(nn.Module):
if
self
.
use_output
:
if
self
.
use_output
:
output_shape
=
(
output_shape
output_shape
=
(
output_shape
if
output_shape
is
not
None
else
query
.
shape
)
if
output_shape
is
not
None
else
query
.
shape
)
output
=
torch
.
empty
(
output_shape
,
output
=
torch
.
zeros
(
output_shape
,
dtype
=
query
.
dtype
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
device
=
query
.
device
)
hidden_size
=
output_shape
[
-
1
]
hidden_size
=
output_shape
[
-
1
]
...
@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
...
@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
self
.
scale
=
scale
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
,
\
f
"num_heads (
{
self
.
num_heads
}
) is not "
\
f
"divisible by num_kv_heads (
{
self
.
num_kv_heads
}
)"
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
...
@@ -301,12 +306,17 @@ class MultiHeadAttention(nn.Module):
...
@@ -301,12 +306,17 @@ class MultiHeadAttention(nn.Module):
block_size
=
16
,
block_size
=
16
,
is_attention_free
=
False
)
is_attention_free
=
False
)
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
if
backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
}:
if
current_platform
.
is_rocm
():
backend
=
_Backend
.
XFORMERS
# currently, only torch_sdpa is supported on rocm
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
else
:
if
backend
in
(
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
FLEX_ATTENTION
):
backend
=
_Backend
.
XFORMERS
self
.
attn_backend
=
backend
if
backend
in
{
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
}
else
_Backend
.
TORCH_SDPA
}
else
_Backend
.
TORCH_SDPA
def
forward
(
def
forward
(
self
,
self
,
...
@@ -430,6 +440,7 @@ def unified_attention_with_output(
...
@@ -430,6 +440,7 @@ def unified_attention_with_output(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
wait_for_kv_layer_from_connector
(
layer_name
)
wait_for_kv_layer_from_connector
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
...
@@ -444,7 +455,8 @@ def unified_attention_with_output(
...
@@ -444,7 +455,8 @@ def unified_attention_with_output(
value
,
value
,
kv_cache
,
kv_cache
,
attn_metadata
,
attn_metadata
,
output
=
output
)
output
=
output
,
output_scale
=
output_scale
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
maybe_save_kv_layer_to_connector
(
layer_name
,
kv_cache
)
...
@@ -455,6 +467,7 @@ def unified_attention_with_output_fake(
...
@@ -455,6 +467,7 @@ def unified_attention_with_output_fake(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
return
return
...
...
vllm/attention/ops/ipex_attn.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
try
:
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
...
@@ -29,7 +29,7 @@ class _PagedAttention:
...
@@ -29,7 +29,7 @@ class _PagedAttention:
head_size
:
int
,
head_size
:
int
,
*
args
,
*
args
,
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
return
(
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
)
return
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
@
staticmethod
@
staticmethod
def
split_kv_cache
(
def
split_kv_cache
(
...
@@ -120,7 +120,7 @@ class _PagedAttention:
...
@@ -120,7 +120,7 @@ class _PagedAttention:
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
*
args
,
*
args
,
)
->
None
:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
...
...
vllm/attention/ops/nki_flash_attn.py
View file @
99324e25
...
@@ -8,9 +8,7 @@ import torch
...
@@ -8,9 +8,7 @@ import torch
from
neuronxcc
import
nki
from
neuronxcc
import
nki
from
neuronxcc.nki.language
import
par_dim
from
neuronxcc.nki.language
import
par_dim
from
vllm.utils
import
cdiv
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
is_power_of_2
(
x
):
def
is_power_of_2
(
x
):
...
@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
...
@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
(
num_tiles
,
num_blocks_per_tile
))
(
num_tiles
,
num_blocks_per_tile
))
block_tables_sbuf
=
nl
.
zeros
(
block_tables_sbuf
=
nl
.
zeros
(
(
ceil_div
(
num_tiles
,
(
cdiv
(
num_tiles
,
B_P_SIZE
),
par_dim
(
B_P_SIZE
),
num_blocks_per_tile
),
B_P_SIZE
),
par_dim
(
B_P_SIZE
),
num_blocks_per_tile
),
dtype
=
nl
.
int32
,
dtype
=
nl
.
int32
,
)
)
for
i
in
nl
.
affine_range
(
c
eil_
div
(
num_tiles
,
B_P_SIZE
)):
for
i
in
nl
.
affine_range
(
cdiv
(
num_tiles
,
B_P_SIZE
)):
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_f
=
nl
.
arange
(
num_blocks_per_tile
)[
None
,
:]
i_f
=
nl
.
arange
(
num_blocks_per_tile
)[
None
,
:]
block_tables_sbuf
[
i
,
i_p
,
i_f
]
=
nl
.
load
(
block_tables_sbuf
[
i
,
i_p
,
i_f
]
=
nl
.
load
(
...
@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
...
@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
assert
is_power_of_2
(
assert
is_power_of_2
(
num_blocks_per_tile
),
f
"
{
num_blocks_per_tile
=
}
is not power of 2"
num_blocks_per_tile
),
f
"
{
num_blocks_per_tile
=
}
is not power of 2"
num_loads
=
c
eil_
div
(
num_blocks_per_tile
,
B_P_SIZE
)
num_loads
=
cdiv
(
num_blocks_per_tile
,
B_P_SIZE
)
block_tables_transposed
=
nl
.
ndarray
(
block_tables_transposed
=
nl
.
ndarray
(
(
(
num_loads
,
num_loads
,
...
@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
...
@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
"""
# load key cache
# load key cache
num_loads
=
c
eil_
div
(
num_blocks_per_large_tile
,
B_P_SIZE
)
num_loads
=
cdiv
(
num_blocks_per_large_tile
,
B_P_SIZE
)
for
load_idx
in
nl
.
affine_range
(
num_loads
):
for
load_idx
in
nl
.
affine_range
(
num_loads
):
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_f
=
nl
.
arange
(
tiled_block_size
*
B_D_SIZE
)[
None
,
:]
i_f
=
nl
.
arange
(
tiled_block_size
*
B_D_SIZE
)[
None
,
:]
...
@@ -605,7 +602,7 @@ def flash_paged_attention(
...
@@ -605,7 +602,7 @@ def flash_paged_attention(
)
)
for
large_k_tile_idx
in
nl
.
sequential_range
(
0
,
num_large_k_tile
):
for
large_k_tile_idx
in
nl
.
sequential_range
(
0
,
num_large_k_tile
):
num_loads
=
c
eil_
div
(
num_blocks_per_large_tile
,
B_P_SIZE
)
num_loads
=
cdiv
(
num_blocks_per_large_tile
,
B_P_SIZE
)
cur_k_tile
=
nl
.
ndarray
(
cur_k_tile
=
nl
.
ndarray
(
(
par_dim
(
B_D_SIZE
),
LARGE_TILE_SZ
),
(
par_dim
(
B_D_SIZE
),
LARGE_TILE_SZ
),
dtype
=
kernel_dtype
,
dtype
=
kernel_dtype
,
...
...
vllm/attention/ops/pallas_kv_cache_update.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
jax
from
jax.experimental
import
pallas
as
pl
from
jax.experimental.pallas
import
tpu
as
pltpu
from
vllm.utils
import
cdiv
def
_kv_cache_update_kernel
(
# Prefetch
slices_ref
,
# [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
# Input
new_kv_hbm_ref
,
# [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref
,
# [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_
,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch
,
# [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem
,
):
async_copies
=
[]
block_idx
=
pl
.
program_id
(
0
)
num_slices_per_block
=
scratch
.
shape
[
0
]
# Copy from new_kv_hbm_ref to scratch
for
i
in
range
(
num_slices_per_block
):
offset_i
=
i
+
block_idx
*
num_slices_per_block
new_kv_start
=
slices_ref
[
1
,
offset_i
]
length
=
slices_ref
[
2
,
offset_i
]
async_copy
=
pltpu
.
make_async_copy
(
new_kv_hbm_ref
.
at
[
pl
.
ds
(
new_kv_start
,
length
),
...],
scratch
.
at
[
i
,
pl
.
ds
(
0
,
length
),
...],
sem
,
)
async_copy
.
start
()
async_copies
.
append
(
async_copy
)
for
async_copy
in
async_copies
:
async_copy
.
wait
()
# Copy from scratch to kv_cache_hbm_ref
async_copies
.
clear
()
for
i
in
range
(
num_slices_per_block
):
offset_i
=
i
+
block_idx
*
num_slices_per_block
kv_cache_start
=
slices_ref
[
0
,
offset_i
]
length
=
slices_ref
[
2
,
offset_i
]
async_copy
=
pltpu
.
make_async_copy
(
scratch
.
at
[
i
,
pl
.
ds
(
0
,
length
),
...],
kv_cache_hbm_ref
.
at
[
pl
.
ds
(
kv_cache_start
,
length
),
...],
sem
,
)
async_copy
.
start
()
async_copies
.
append
(
async_copy
)
for
async_copy
in
async_copies
:
async_copy
.
wait
()
@
functools
.
partial
(
jax
.
jit
,
static_argnames
=
[
"page_size"
,
"num_slices_per_block"
],
)
def
kv_cache_update
(
new_kv
:
jax
.
Array
,
# [total_num_token, num_combined_kv_heads, head_dim]
slices
:
jax
.
Array
,
# [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache
:
jax
.
Array
,
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices
:
jax
.
Array
,
# [1]
*
,
page_size
:
int
=
32
,
num_slices_per_block
:
int
=
8
,
):
assert
slices
.
shape
[
1
]
%
num_slices_per_block
==
0
_
,
num_combined_kv_heads
,
head_dim
=
new_kv
.
shape
assert
kv_cache
.
shape
[
1
]
==
num_combined_kv_heads
assert
kv_cache
.
shape
[
2
]
==
head_dim
assert
head_dim
%
128
==
0
# TODO: Add dynamic check to make sure that the all the slice lengths are
# smaller or equal to page_size
in_specs
=
[
pl
.
BlockSpec
(
memory_space
=
pltpu
.
TPUMemorySpace
.
ANY
),
pl
.
BlockSpec
(
memory_space
=
pltpu
.
TPUMemorySpace
.
ANY
),
]
out_specs
=
[
pl
.
BlockSpec
(
memory_space
=
pltpu
.
TPUMemorySpace
.
ANY
)]
out_shape
=
[
jax
.
ShapeDtypeStruct
(
kv_cache
.
shape
,
dtype
=
kv_cache
.
dtype
)]
scalar_prefetches
=
[
slices
]
scratch
=
pltpu
.
VMEM
(
(
num_slices_per_block
,
page_size
,
num_combined_kv_heads
,
head_dim
),
new_kv
.
dtype
,
)
scratch_shapes
=
[
scratch
,
pltpu
.
SemaphoreType
.
DMA
,
]
kernel
=
pl
.
pallas_call
(
_kv_cache_update_kernel
,
grid_spec
=
pltpu
.
PrefetchScalarGridSpec
(
num_scalar_prefetch
=
len
(
scalar_prefetches
),
in_specs
=
in_specs
,
out_specs
=
out_specs
,
grid
=
(
cdiv
(
num_kv_update_slices
[
0
],
num_slices_per_block
),
),
scratch_shapes
=
scratch_shapes
,
),
out_shape
=
out_shape
,
input_output_aliases
=
{
len
(
scalar_prefetches
)
+
1
:
0
},
)
return
kernel
(
*
scalar_prefetches
,
new_kv
,
kv_cache
)[
0
]
vllm/attention/ops/triton_flash_attention.py
View file @
99324e25
...
@@ -25,9 +25,14 @@ Not currently supported:
...
@@ -25,9 +25,14 @@ Not currently supported:
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
on_gfx1x
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
# Avoid misleading ROCm warning.
if
current_platform
.
is_rocm
():
from
vllm.platforms.rocm
import
on_gfx1x
else
:
on_gfx1x
=
lambda
*
args
,
**
kwargs
:
False
torch_dtype
:
tl
.
constexpr
=
torch
.
float16
torch_dtype
:
tl
.
constexpr
=
torch
.
float16
...
...
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment