"docs/vscode:/vscode.git/clone" did not exist on "022afbeb4efa22bb8a4656a2712cd66c6a811c23"
Unverified commit 70406eb1, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention][V0 Deprecation] Deprecate accept output buffer (#39125)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 08bfedc1
...@@ -216,12 +216,14 @@ def test_splitting_ops_dynamic(): ...@@ -216,12 +216,14 @@ def test_splitting_ops_dynamic():
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"], splitting_ops=["vllm::unified_attention_with_output"],
) )
) )
# with inductor partition we use splitting_ops directly for # with inductor partition we use splitting_ops directly for
# partition rules # partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"] assert config.compilation_config.splitting_ops == [
"vllm::unified_attention_with_output"
]
# When attn_fusion pass enabled. # When attn_fusion pass enabled.
config = VllmConfig( config = VllmConfig(
...@@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition(): ...@@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
splitting_ops=[ splitting_ops=[
"vllm::unified_attention", "vllm::unified_attention_with_output",
"vllm::moe_forward", "vllm::moe_forward",
"vllm::moe_forward_shared", "vllm::moe_forward_shared",
], ],
...@@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition(): ...@@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
) )
splitting_ops = config.compilation_config.splitting_ops splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops == [ assert splitting_ops == [
"vllm::unified_attention", "vllm::unified_attention_with_output",
"vllm::moe_forward", "vllm::moe_forward",
"vllm::moe_forward_shared", "vllm::moe_forward_shared",
] ]
......
...@@ -282,7 +282,7 @@ class PassConfig: ...@@ -282,7 +282,7 @@ class PassConfig:
""" """
enabled_fusions = [ enabled_fusions = [
f.name[len("fuse_") :] f.name[len("fuse_") :]
for f in fields(self) for f in fields(self) # type: ignore[arg-type]
if getattr(self, f.name) and f.name.startswith("fuse_") if getattr(self, f.name) and f.name.startswith("fuse_")
] ]
...@@ -711,9 +711,7 @@ class CompilationConfig: ...@@ -711,9 +711,7 @@ class CompilationConfig:
# Attention ops; used for piecewise cudagraphs # Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name" # Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [ _attention_ops: ClassVar[list[str]] = [
"vllm::unified_attention",
"vllm::unified_attention_with_output", "vllm::unified_attention_with_output",
"vllm::unified_mla_attention",
"vllm::unified_mla_attention_with_output", "vllm::unified_mla_attention_with_output",
"vllm::mamba_mixer2", "vllm::mamba_mixer2",
"vllm::mamba_mixer", "vllm::mamba_mixer",
......
...@@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase):
# and let torch.compile handle them. # and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op() self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}") raise ValueError(f"Duplicate layer name: {prefix}")
...@@ -429,14 +428,11 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -429,14 +428,11 @@ class Attention(nn.Module, AttentionLayerBase):
if self.impl.supports_quant_query_input: if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale) query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
if output_shape is None: if output_shape is None:
# Handle both 2D [num_tokens, hidden] and # Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query # 3D [num_tokens, heads, head_dim] query
num_tokens = query.shape[0] num_tokens = query.shape[0]
output_shape = torch.Size( output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v))
(num_tokens, self.num_heads * self.head_size_v)
)
output = torch.empty(output_shape, dtype=output_dtype, device=query.device) output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1] hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
...@@ -488,16 +484,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -488,16 +484,6 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
)
if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
)
def calc_kv_scales(self, query, key, value): def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._q_scale.copy_(torch.abs(query).max() / self.q_range)
...@@ -633,35 +619,6 @@ def get_attention_context( ...@@ -633,35 +619,6 @@ def get_attention_context(
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
@maybe_transfer_kv_layer
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
return output
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
)
def unified_kv_cache_update( def unified_kv_cache_update(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
......
...@@ -133,7 +133,7 @@ def create_cross_attention_backend( ...@@ -133,7 +133,7 @@ def create_cross_attention_backend(
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -494,7 +494,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -494,7 +494,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache_dtype, self.kv_cache_dtype,
self._k_scale, self._k_scale,
) )
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device) output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.forward_impl( self.forward_impl(
q, q,
...@@ -505,10 +504,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -505,10 +504,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
output=output, output=output,
) )
return output return output
else:
return self.forward_impl(
q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
else: else:
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
kv_c_normed, kv_c_normed,
...@@ -517,7 +512,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -517,7 +512,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache_dtype, self.kv_cache_dtype,
self._k_scale, self._k_scale,
) )
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device) output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output( torch.ops.vllm.unified_mla_attention_with_output(
q, q,
...@@ -528,14 +522,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -528,14 +522,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
return output return output
else:
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
def forward_impl( def forward_impl(
self, self,
...@@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe: torch.Tensor, # value in unified attn k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: "MLACommonMetadata", attn_metadata: "MLACommonMetadata",
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
use_quant = output_scale is not None or output_block_scale is not None use_quant = output_scale is not None or output_block_scale is not None
if use_quant: if use_quant:
# The fusion pass has allocated output with quantized dtype # The fusion pass has allocated output with quantized dtype
...@@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
out.copy_(out_new) # Copy result out.copy_(out_new) # Copy result
@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
def unified_mla_kv_cache_update( def unified_mla_kv_cache_update(
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
...@@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800 ...@@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800
class MLACommonBackend(AttentionBackend): class MLACommonBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA" return "TRITON_MLA"
......
...@@ -94,7 +94,6 @@ def basic_cache( ...@@ -94,7 +94,6 @@ def basic_cache(
class CacheOnlyAttentionBackend(AttentionBackend): class CacheOnlyAttentionBackend(AttentionBackend):
"""Attention backend that only caches KV without computing attention.""" """Attention backend that only caches KV without computing attention."""
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [ supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
......
...@@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling(
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -53,10 +53,6 @@ class MultipleOf: ...@@ -53,10 +53,6 @@ class MultipleOf:
class AttentionBackend(ABC): class AttentionBackend(ABC):
"""Abstract class for attention backends.""" """Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [ supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
"auto", "auto",
...@@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]): ...@@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3 ...@@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3
class CPUAttentionBackend(AttentionBackend): class CPUAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [ supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
...@@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl): ...@@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: CPUAttentionMetadata | None, attn_metadata: CPUAttentionMetadata | None,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl): ...@@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
......
...@@ -62,7 +62,6 @@ logger = init_logger(__name__) ...@@ -62,7 +62,6 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
...@@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads). {q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values We use torch's .expand() to avoid duplicating values
""" """
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, ( assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected." "FlashAttention version not detected."
) )
......
...@@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): ...@@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): ...@@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads). {q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values We use torch's .expand() to avoid duplicating values
""" """
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, ( assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected." "FlashAttention version not detected."
) )
......
...@@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper: ...@@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper:
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
...@@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl): ...@@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
if attn_metadata is None: if attn_metadata is None:
# Profiling run. # Profiling run.
return output.fill_(0) return output.fill_(0)
......
...@@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): ...@@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
class FlexAttentionBackend(AttentionBackend): class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [ supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
...@@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlexAttentionMetadata, attn_metadata: FlexAttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported for FlexAttentionImpl" "fused output quantization is not yet supported for FlexAttentionImpl"
......
...@@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend): ...@@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend):
for models like DeepSeek-V3.2 that use index-based sparse attention. for models like DeepSeek-V3.2 that use index-based sparse attention.
""" """
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
......
...@@ -78,7 +78,6 @@ structured as: ...@@ -78,7 +78,6 @@ structured as:
class FlashMLASparseBackend(AttentionBackend): class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
......
...@@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton( ...@@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton(
class ROCMAiterMLASparseBackend(AttentionBackend): class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
......
...@@ -35,7 +35,6 @@ logger = init_logger(__name__) ...@@ -35,7 +35,6 @@ logger = init_logger(__name__)
class XPUMLASparseBackend(AttentionBackend): class XPUMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
......
...@@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend): class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
...@@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AiterFlashAttentionMetadata, attn_metadata: AiterFlashAttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads). {q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values We use torch's .expand() to avoid duplicating values
""" """
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported " "fused output quantization is not yet supported "
......
...@@ -24,8 +24,6 @@ logger = init_logger(__name__) ...@@ -24,8 +24,6 @@ logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend): class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)] return [MultipleOf(16)]
...@@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None: if output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused block_scale output quantization is not yet supported" "fused block_scale output quantization is not yet supported"
......
...@@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat ...@@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend): class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [ supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
...@@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None: if output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused block_scale output quantization is not yet supported" "fused block_scale output quantization is not yet supported"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment