Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
70406eb1
Unverified
Commit
70406eb1
authored
Apr 07, 2026
by
Lucas Wilkinson
Committed by
GitHub
Apr 07, 2026
Browse files
[Attention][V0 Deprecation] Deprecate accept output buffer (#39125)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
08bfedc1
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
92 additions
and
219 deletions
+92
-219
tests/compile/test_config.py
tests/compile/test_config.py
+6
-4
vllm/config/compilation.py
vllm/config/compilation.py
+1
-3
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+53
-96
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+1
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+21
-76
vllm/model_executor/models/extract_hidden_states.py
vllm/model_executor/models/extract_hidden_states.py
+0
-1
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+1
-1
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+1
-5
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+1
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-3
vllm/v1/attention/backends/flash_attn_diffkv.py
vllm/v1/attention/backends/flash_attn_diffkv.py
+1
-2
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+1
-4
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+1
-3
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+0
-1
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+0
-1
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+0
-1
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
+0
-1
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+1
-4
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+1
-5
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+1
-4
No files found.
tests/compile/test_config.py
View file @
70406eb1
...
@@ -216,12 +216,14 @@ def test_splitting_ops_dynamic():
...
@@ -216,12 +216,14 @@ def test_splitting_ops_dynamic():
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor_graph_partition
=
True
,
use_inductor_graph_partition
=
True
,
splitting_ops
=
[
"vllm::unified_attention"
],
splitting_ops
=
[
"vllm::unified_attention
_with_output
"
],
)
)
)
)
# with inductor partition we use splitting_ops directly for
# with inductor partition we use splitting_ops directly for
# partition rules
# partition rules
assert
config
.
compilation_config
.
splitting_ops
==
[
"vllm::unified_attention"
]
assert
config
.
compilation_config
.
splitting_ops
==
[
"vllm::unified_attention_with_output"
]
# When attn_fusion pass enabled.
# When attn_fusion pass enabled.
config
=
VllmConfig
(
config
=
VllmConfig
(
...
@@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
...
@@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
mode
=
CompilationMode
.
VLLM_COMPILE
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_inductor_graph_partition
=
True
,
use_inductor_graph_partition
=
True
,
splitting_ops
=
[
splitting_ops
=
[
"vllm::unified_attention"
,
"vllm::unified_attention
_with_output
"
,
"vllm::moe_forward"
,
"vllm::moe_forward"
,
"vllm::moe_forward_shared"
,
"vllm::moe_forward_shared"
,
],
],
...
@@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
...
@@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
)
)
splitting_ops
=
config
.
compilation_config
.
splitting_ops
splitting_ops
=
config
.
compilation_config
.
splitting_ops
assert
splitting_ops
==
[
assert
splitting_ops
==
[
"vllm::unified_attention"
,
"vllm::unified_attention
_with_output
"
,
"vllm::moe_forward"
,
"vllm::moe_forward"
,
"vllm::moe_forward_shared"
,
"vllm::moe_forward_shared"
,
]
]
...
...
vllm/config/compilation.py
View file @
70406eb1
...
@@ -282,7 +282,7 @@ class PassConfig:
...
@@ -282,7 +282,7 @@ class PassConfig:
"""
"""
enabled_fusions
=
[
enabled_fusions
=
[
f
.
name
[
len
(
"fuse_"
)
:]
f
.
name
[
len
(
"fuse_"
)
:]
for
f
in
fields
(
self
)
for
f
in
fields
(
self
)
# type: ignore[arg-type]
if
getattr
(
self
,
f
.
name
)
and
f
.
name
.
startswith
(
"fuse_"
)
if
getattr
(
self
,
f
.
name
)
and
f
.
name
.
startswith
(
"fuse_"
)
]
]
...
@@ -711,9 +711,7 @@ class CompilationConfig:
...
@@ -711,9 +711,7 @@ class CompilationConfig:
# Attention ops; used for piecewise cudagraphs
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
# Use PyTorch operator format: "namespace::name"
_attention_ops
:
ClassVar
[
list
[
str
]]
=
[
_attention_ops
:
ClassVar
[
list
[
str
]]
=
[
"vllm::unified_attention"
,
"vllm::unified_attention_with_output"
,
"vllm::unified_attention_with_output"
,
"vllm::unified_mla_attention"
,
"vllm::unified_mla_attention_with_output"
,
"vllm::unified_mla_attention_with_output"
,
"vllm::mamba_mixer2"
,
"vllm::mamba_mixer2"
,
"vllm::mamba_mixer"
,
"vllm::mamba_mixer"
,
...
...
vllm/model_executor/layers/attention/attention.py
View file @
70406eb1
...
@@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase):
# and let torch.compile handle them.
# and let torch.compile handle them.
self
.
use_direct_call
=
not
current_platform
.
opaque_attention_op
()
self
.
use_direct_call
=
not
current_platform
.
opaque_attention_op
()
self
.
use_output
=
self
.
attn_backend
.
accept_output_buffer
compilation_config
=
vllm_config
.
compilation_config
compilation_config
=
vllm_config
.
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
...
@@ -429,75 +428,62 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -429,75 +428,62 @@ class Attention(nn.Module, AttentionLayerBase):
if
self
.
impl
.
supports_quant_query_input
:
if
self
.
impl
.
supports_quant_query_input
:
query
,
_
=
self
.
query_quant
(
query
,
self
.
_q_scale
)
query
,
_
=
self
.
query_quant
(
query
,
self
.
_q_scale
)
if
self
.
use_output
:
if
output_shape
is
None
:
if
output_shape
is
None
:
# Handle both 2D [num_tokens, hidden] and
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
# 3D [num_tokens, heads, head_dim] query
num_tokens
=
query
.
shape
[
0
]
num_tokens
=
query
.
shape
[
0
]
output_shape
=
torch
.
Size
((
num_tokens
,
self
.
num_heads
*
self
.
head_size_v
))
output_shape
=
torch
.
Size
(
output
=
torch
.
empty
(
output_shape
,
dtype
=
output_dtype
,
device
=
query
.
device
)
(
num_tokens
,
self
.
num_heads
*
self
.
head_size_v
)
hidden_size
=
output_shape
[
-
1
]
)
# Reshape the query, key, and value tensors.
output
=
torch
.
empty
(
output_shape
,
dtype
=
output_dtype
,
device
=
query
.
device
)
# NOTE(woosuk): We do this outside the custom op to minimize the
hidden_size
=
output_shape
[
-
1
]
# CPU overheads from the non-CUDA-graph regions.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
# NOTE(woosuk): We do this outside the custom op to minimize the
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size_v
)
# CPU overheads from the non-CUDA-graph regions.
if
key
is
not
None
:
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size_v
)
if
value
is
not
None
:
if
key
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache_dummy_dep
=
None
if
value
is
not
None
:
if
self
.
use_direct_call
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
# Skip this if sharing KV cache with an earlier attention layer.
kv_cache_dummy_dep
=
None
if
(
if
self
.
use_direct_call
:
not
self
.
attn_backend
.
forward_includes_kv_cache_update
# Skip this if sharing KV cache with an earlier attention layer.
and
self
.
kv_sharing_target_layer_name
is
None
if
(
and
key
is
not
None
not
self
.
attn_backend
.
forward_includes_kv_cache_update
and
value
is
not
None
and
self
.
kv_sharing_target_layer_name
is
None
):
and
key
is
not
None
kv_cache_dummy_dep
=
unified_kv_cache_update
(
and
value
is
not
None
key
,
value
,
self
.
layer_name
):
kv_cache_dummy_dep
=
unified_kv_cache_update
(
key
,
value
,
self
.
layer_name
)
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
else
:
# Skip this if sharing KV cache with an earlier attention layer.
if
(
not
self
.
attn_backend
.
forward_includes_kv_cache_update
and
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
key
,
value
,
self
.
layer_name
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
return
output
.
view
(
-
1
,
hidden_size
)
unified_attention_with_output
(
else
:
query
,
assert
self
.
attn_backend
.
forward_includes_kv_cache_update
,
(
key
,
"Split KV cache update not supported when output tensor not provided."
value
,
output
,
self
.
layer_name
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
if
self
.
use_direct_call
:
else
:
return
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
# Skip this if sharing KV cache with an earlier attention layer.
else
:
if
(
return
torch
.
ops
.
vllm
.
unified_attention
(
not
self
.
attn_backend
.
forward_includes_kv_cache_update
query
,
key
,
value
,
self
.
layer_name
and
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
key
,
value
,
self
.
layer_name
)
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
output
.
view
(
-
1
,
hidden_size
)
def
calc_kv_scales
(
self
,
query
,
key
,
value
):
def
calc_kv_scales
(
self
,
query
,
key
,
value
):
self
.
_q_scale
.
copy_
(
torch
.
abs
(
query
).
max
()
/
self
.
q_range
)
self
.
_q_scale
.
copy_
(
torch
.
abs
(
query
).
max
()
/
self
.
q_range
)
...
@@ -633,35 +619,6 @@ def get_attention_context(
...
@@ -633,35 +619,6 @@ def get_attention_context(
return
attn_metadata
,
attn_layer
,
kv_cache
,
layer_slot_mapping
return
attn_metadata
,
attn_layer
,
kv_cache
,
layer_slot_mapping
@
maybe_transfer_kv_layer
def
unified_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
attn_metadata
,
self
,
kv_cache
,
_
=
get_attention_context
(
layer_name
)
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
return
output
def
unified_attention_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_attention"
,
op_func
=
unified_attention
,
fake_impl
=
unified_attention_fake
,
)
def
unified_kv_cache_update
(
def
unified_kv_cache_update
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
70406eb1
...
@@ -133,7 +133,7 @@ def create_cross_attention_backend(
...
@@ -133,7 +133,7 @@ def create_cross_attention_backend(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
70406eb1
...
@@ -494,21 +494,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -494,21 +494,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
_k_scale
,
self
.
_k_scale
,
)
)
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
self
.
forward_impl
(
self
.
forward_impl
(
q
,
q
,
kv_c_normed
,
kv_c_normed
,
k_pe
,
k_pe
,
self_kv_cache
,
self_kv_cache
,
attn_metadata
,
attn_metadata
,
output
=
output
,
output
=
output
,
)
)
return
output
return
output
else
:
return
self
.
forward_impl
(
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
)
else
:
else
:
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_mla_kv_cache_update
(
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_mla_kv_cache_update
(
kv_c_normed
,
kv_c_normed
,
...
@@ -517,25 +512,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -517,25 +512,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
_k_scale
,
self
.
_k_scale
,
)
)
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
torch
.
ops
.
vllm
.
unified_mla_attention_with_output
(
torch
.
ops
.
vllm
.
unified_mla_attention_with_output
(
q
,
q
,
kv_c_normed
,
kv_c_normed
,
k_pe
,
k_pe
,
output
,
output
,
self
.
layer_name
,
self
.
layer_name
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
return
output
return
output
else
:
return
torch
.
ops
.
vllm
.
unified_mla_attention
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
def
forward_impl
(
def
forward_impl
(
self
,
self
,
...
@@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe
:
torch
.
Tensor
,
# value in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
"MLACommonMetadata"
,
attn_metadata
:
"MLACommonMetadata"
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
output
is
not
None
,
"Output tensor must be provided."
use_quant
=
output_scale
is
not
None
or
output_block_scale
is
not
None
use_quant
=
output_scale
is
not
None
or
output_block_scale
is
not
None
if
use_quant
:
if
use_quant
:
# The fusion pass has allocated output with quantized dtype
# The fusion pass has allocated output with quantized dtype
...
@@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
out
.
copy_
(
out_new
)
# Copy result
out
.
copy_
(
out_new
)
# Copy result
@
maybe_transfer_kv_layer
def
unified_mla_attention
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
kv_cache_dummy_dep
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del
kv_cache_dummy_dep
attn_metadata
,
layer
,
kv_cache
,
_
=
get_attention_context
(
layer_name
)
output
=
layer
.
forward_impl
(
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
)
return
output
def
unified_mla_attention_fake
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
kv_cache_dummy_dep
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_mla_attention"
,
op_func
=
unified_mla_attention
,
mutates_args
=
[],
fake_impl
=
unified_mla_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
def
unified_mla_kv_cache_update
(
def
unified_mla_kv_cache_update
(
kv_c_normed
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
...
@@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800
...
@@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800
class
MLACommonBackend
(
AttentionBackend
):
class
MLACommonBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
return
"TRITON_MLA"
return
"TRITON_MLA"
...
...
vllm/model_executor/models/extract_hidden_states.py
View file @
70406eb1
...
@@ -94,7 +94,6 @@ def basic_cache(
...
@@ -94,7 +94,6 @@ def basic_cache(
class
CacheOnlyAttentionBackend
(
AttentionBackend
):
class
CacheOnlyAttentionBackend
(
AttentionBackend
):
"""Attention backend that only caches KV without computing attention."""
"""Attention backend that only caches KV without computing attention."""
accept_output_buffer
:
bool
=
False
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
...
...
vllm/model_executor/models/whisper_causal.py
View file @
70406eb1
...
@@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling(
...
@@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/v1/attention/backend.py
View file @
70406eb1
...
@@ -53,10 +53,6 @@ class MultipleOf:
...
@@ -53,10 +53,6 @@ class MultipleOf:
class
AttentionBackend
(
ABC
):
class
AttentionBackend
(
ABC
):
"""Abstract class for attention backends."""
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer
:
bool
=
False
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
"CacheDType"
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
"CacheDType"
]]
=
[
"auto"
,
"auto"
,
...
@@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
...
@@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
T
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
70406eb1
...
@@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3
...
@@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3
class
CPUAttentionBackend
(
AttentionBackend
):
class
CPUAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
...
@@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
...
@@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
CPUAttentionMetadata
|
None
,
attn_metadata
:
CPUAttentionMetadata
|
None
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
...
@@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"fused output quantization is not yet supported"
"fused output quantization is not yet supported"
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
70406eb1
...
@@ -62,7 +62,6 @@ logger = init_logger(__name__)
...
@@ -62,7 +62,6 @@ logger = init_logger(__name__)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
@@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
We use torch's .expand() to avoid duplicating values
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
assert
self
.
vllm_flash_attn_version
is
not
None
,
(
assert
self
.
vllm_flash_attn_version
is
not
None
,
(
"FlashAttention version not detected."
"FlashAttention version not detected."
)
)
...
...
vllm/v1/attention/backends/flash_attn_diffkv.py
View file @
70406eb1
...
@@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
...
@@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
...
@@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
We use torch's .expand() to avoid duplicating values
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
assert
self
.
vllm_flash_attn_version
is
not
None
,
(
assert
self
.
vllm_flash_attn_version
is
not
None
,
(
"FlashAttention version not detected."
"FlashAttention version not detected."
)
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
70406eb1
...
@@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper:
...
@@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper:
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
@@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashInferMetadata
,
attn_metadata
:
FlashInferMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
# Profiling run.
# Profiling run.
return
output
.
fill_
(
0
)
return
output
.
fill_
(
0
)
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
70406eb1
...
@@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
...
@@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
class
FlexAttentionBackend
(
AttentionBackend
):
class
FlexAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
...
@@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlexAttentionMetadata
,
attn_metadata
:
FlexAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"fused output quantization is not yet supported for FlexAttentionImpl"
"fused output quantization is not yet supported for FlexAttentionImpl"
...
...
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
View file @
70406eb1
...
@@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend):
...
@@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend):
for models like DeepSeek-V3.2 that use index-based sparse attention.
for models like DeepSeek-V3.2 that use index-based sparse attention.
"""
"""
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
70406eb1
...
@@ -78,7 +78,6 @@ structured as:
...
@@ -78,7 +78,6 @@ structured as:
class
FlashMLASparseBackend
(
AttentionBackend
):
class
FlashMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
70406eb1
...
@@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton(
...
@@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton(
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
...
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
View file @
70406eb1
...
@@ -35,7 +35,6 @@ logger = init_logger(__name__)
...
@@ -35,7 +35,6 @@ logger = init_logger(__name__)
class
XPUMLASparseBackend
(
AttentionBackend
):
class
XPUMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
70406eb1
...
@@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder(
class
AiterFlashAttentionBackend
(
AttentionBackend
):
class
AiterFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"auto"
,
...
@@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AiterFlashAttentionMetadata
,
attn_metadata
:
AiterFlashAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
We use torch's .expand() to avoid duplicating values
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"fused output quantization is not yet supported "
"fused output quantization is not yet supported "
...
...
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
View file @
70406eb1
...
@@ -24,8 +24,6 @@ logger = init_logger(__name__)
...
@@ -24,8 +24,6 @@ logger = init_logger(__name__)
class
RocmAiterUnifiedAttentionBackend
(
RocmAttentionBackend
):
class
RocmAiterUnifiedAttentionBackend
(
RocmAttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
return
[
MultipleOf
(
16
)]
...
@@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
...
@@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
...
@@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_block_scale
is
not
None
:
if
output_block_scale
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"fused block_scale output quantization is not yet supported"
"fused block_scale output quantization is not yet supported"
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
70406eb1
...
@@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
...
@@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class
RocmAttentionBackend
(
AttentionBackend
):
class
RocmAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
...
@@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl):
...
@@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl):
...
@@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_block_scale
is
not
None
:
if
output_block_scale
is
not
None
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"fused block_scale output quantization is not yet supported"
"fused block_scale output quantization is not yet supported"
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment