Unverified commit f98548b9, authored by Luka Govedič and committed by GitHub
Browse files

[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
parent 96846bb3
...@@ -305,6 +305,7 @@ steps: ...@@ -305,6 +305,7 @@ steps:
commands: commands:
- pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py - pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_async_tp.py
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from typing import Callable, Union from typing import Callable, Union
from torch import fx from torch import fx
from torch._ops import OpOverload
from vllm.compilation.fx_utils import (find_specified_fn, from vllm.compilation.fx_utils import find_op_nodes
find_specified_fn_maybe)
from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
...@@ -48,18 +49,19 @@ class TestBackend: ...@@ -48,18 +49,19 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph # assign by reference, will reflect the final state of the graph
self.final_graph = graph self.final_graph = graph
def check_before_ops(self, ops, def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe, \
ops_fully_replaced=True):
for op in ops: for op in ops:
find_fn(self.graph_pre_pass.nodes, op) num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
if ops_fully_replaced: num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops, def check_after_ops(self, ops: Sequence[OpOverload]):
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe):
for op in ops: for op in ops:
find_fn(self.graph_post_pass.nodes, op) num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
\ No newline at end of file
...@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, ...@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# In pre-nodes, all gather or reduce scatter should exist, # In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
ops_fully_replaced=False)
# In post-nodes, fused_matmul_reduce_scatter or \ # In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist # fused_all_gather_matmul should exist
......
...@@ -7,8 +7,7 @@ import torch ...@@ -7,8 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.plugins import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey) FusionPass, GroupShape, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig) VllmConfig)
...@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module): ...@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
self.cutlass_fp8_enabled = cutlass_fp8_enabled self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE, self.key = QuantKey(dtype=FP8_DTYPE,
static=static, static=static,
per_tensor=static, group_shape=group_shape,
symmetric=True) symmetric=True)
if static: if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
...@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, ...@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# In pre-nodes, fp8 quant should be there and fused kernels should not # In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn, backend.check_before_ops(model.ops_in_model_before())
find_auto_fn_maybe)
# In post-nodes, fused kernels should be there and fp8 quant should not # In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_after_ops(model.ops_in_model_after(), find_auto_fn, backend.check_after_ops(model.ops_in_model_after())
find_auto_fn_maybe)
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch._dynamo
from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform
# Module-level slots for the custom Dynamo backends. The compilation config
# refers to the backend by a dotted import string
# ("tests.compile.test_fusion_attn.<name>"), so the instances must be
# reachable as attributes of this module.
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
    """
    Check that AttnFusionPass fuses output quantization onto attention.

    The model is run twice with greedy sampling: once with only
    NoOpEliminationPass (unfused reference) and once with AttnFusionPass
    added. The test then checks that
      * the quant op count drops when at least one layer supports fusion,
      * each attention node gains an ``output_scale`` exactly when its
        layer reports support via ``fused_output_quant_supported``,
      * fused and unfused generations are equal.

    :param example_prompts: pytest fixture providing the prompt list
    :param monkeypatch: pytest fixture used to set environment variables
    :param model: model name passed to ``LLM``
    :param quant_key: quantization scheme expected to be fused
    :param use_triton_fa: value exported as VLLM_USE_TRITON_FLASH_ATTN
    """
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # Use global backends
    global backend, backend_unfused

    use_v1 = False  # can be made a param once V1 support added
    monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    # ---- Reference run: no attention fusion ----
    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.9,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    # ---- Fused run: NoOpElimination + AttnFusionPass ----
    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.9,
               max_model_len=2048)

    # check support (only the layers are needed, not their names)
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                quant_key.static,
                                                quant_key.group_shape)
        for layer in compile_config.static_forward_context.values()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i, (node_pre, node_post) in enumerate(
            zip(attn_nodes_pre, attn_nodes_post)):
        assert node_pre.kwargs["output_scale"] is None
        fused = node_post.kwargs["output_scale"] is not None
        expected = attn_fusion_supported[i]
        # The message states the expectation, i.e. the side of the
        # comparison that was violated when this assertion fires.
        assert fused == expected, \
            f"Node {i} {'' if expected else 'not '}expected " \
            f"to have fused output quant"

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None
...@@ -1225,6 +1225,7 @@ def scaled_fp8_quant( ...@@ -1225,6 +1225,7 @@ def scaled_fp8_quant(
num_token_padding: Optional[int] = None, num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False, use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Quantize input tensor to FP8 and return quantized tensor and scale. Quantize input tensor to FP8 and return quantized tensor and scale.
...@@ -1256,7 +1257,12 @@ def scaled_fp8_quant( ...@@ -1256,7 +1257,12 @@ def scaled_fp8_quant(
out_dtype: torch.dtype = current_platform.fp8_dtype() out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding: if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1]) shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype) output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None: if scale is None:
if use_per_token_if_dynamic: if use_per_token_if_dynamic:
......
...@@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]):
    """
    Report whether this attention implementation can quantize its own
    output (fused output quantization).

    AttnFusionPass consults this so that output quantization is only fused
    onto implementations that support it. This default declines every
    scheme; an implementation that supports fusion overrides it.

    TODO(luka) merge parameters into QuantDescriptor

    :param dtype: quantized dtype
    :param static: static or dynamic quantization
    :param group_shape: quant group shape. (-1, -1) for per-tensor.
    :return: is fusion supported for this type of quantization
    """
    return False
class MLAAttentionImpl(AttentionImpl[T], Generic[T]): class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
...@@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): ...@@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
......
...@@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
......
...@@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): ...@@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: DualChunkFlashAttentionMetadata, attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention. """Forward pass with DualChunkFlashAttention.
Args: Args:
...@@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): ...@@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
( (
query, query,
query_succ, query_succ,
......
...@@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert ( assert (
......
...@@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
# TODO: directly write to output tensor # TODO: directly write to output tensor
num_heads: int = self.num_heads num_heads: int = self.num_heads
head_size: int = self.head_size head_size: int = self.head_size
......
...@@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata, attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape _, seq_len_kv, _ = key.shape
......
...@@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
...@@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for IpexAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
......
...@@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if output is not None: if output is not None:
raise NotImplementedError( raise NotImplementedError(
"output is not yet supported for MLAImplBase") "output is not yet supported for MLAImplBase")
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
if attn_metadata.is_profile_run and \ if attn_metadata.is_profile_run and \
attn_metadata.context_chunk_workspace is not None: attn_metadata.context_chunk_workspace is not None:
# During the profile run try to simulate to worse case output size # During the profile run try to simulate to worse case output size
......
...@@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[torch.Tensor, torch.Tensor], kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
...@@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size) query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
......
...@@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim)) head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]):
    """
    Report whether fused output quantization is available for the given
    quantization scheme.

    :param dtype: quantized dtype
    :param static: static or dynamic quantization
    :param group_shape: quant group shape. (-1, -1) for per-tensor.
    :return: is fusion supported for this type of quantization
    """
    # Only supported in the Triton backend
    if not self.use_triton_flash_attn:
        return False

    # Triton path: requires the platform FP8 dtype, static scales and
    # per-tensor granularity.
    is_platform_fp8 = dtype == current_platform.fp8_dtype()
    return is_platform_fp8 and static and group_shape == (-1, -1)
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
...@@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None and not self.use_triton_flash_attn:
raise NotImplementedError(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
if key is not None: if key is not None:
assert value is not None assert value is not None
...@@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks[0][None] attn_masks[0][None]
if attn_masks is not None else None, if attn_masks is not None else None,
full_scales, full_scales,
output_scale,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
...@@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window, decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes) self.kv_cache_dtype, self.alibi_slopes)
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
...@@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert _PARTITION_SIZE_ROCM % block_size == 0 assert _PARTITION_SIZE_ROCM % block_size == 0
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size), size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype, dtype=query.dtype,
device=output.device, device=output.device,
) )
exp_sums = torch.empty( exp_sums = torch.empty(
...@@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
output_scale,
) )
else: else:
output[num_prefill_tokens:] = paged_attn.forward_decode( # PagedAttention does not support fused quant, manually quantize
if output_scale is None:
out_pa = output[num_prefill_tokens:]
else:
out_pa = torch.empty_like(output[num_prefill_tokens:],
dtype=query.dtype)
out_pa[:] = paged_attn.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
...@@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
# Manually perform quantization
if output_scale is not None:
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
out_q = output.view(-1, self.num_heads * self.head_size)
ops.scaled_fp8_quant(out_uq,
output_scale,
output=out_q[num_prefill_tokens:])
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
......
...@@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")
# For warming-up # For warming-up
if attn_metadata is None: if attn_metadata is None:
......
...@@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
attn_type = self.attn_type attn_type = self.attn_type
# Check that appropriate attention metadata attributes are # Check that appropriate attention metadata attributes are
# selected for the desired attention type # selected for the desired attention type
......
...@@ -430,6 +430,7 @@ def unified_attention_with_output( ...@@ -430,6 +430,7 @@ def unified_attention_with_output(
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None: ) -> None:
wait_for_kv_layer_from_connector(layer_name) wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
...@@ -444,7 +445,8 @@ def unified_attention_with_output( ...@@ -444,7 +445,8 @@ def unified_attention_with_output(
value, value,
kv_cache, kv_cache,
attn_metadata, attn_metadata,
output=output) output=output,
output_scale=output_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
...@@ -455,6 +457,7 @@ def unified_attention_with_output_fake( ...@@ -455,6 +457,7 @@ def unified_attention_with_output_fake(
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None: ) -> None:
return return
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, ClassVar, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default ...@@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# A direct NamedTuple subclass cannot carry static members, so the fields
# live on this private proxy and the public class derives from it.
class _GroupShape(NamedTuple):
    row: int
    col: int


class GroupShape(_GroupShape):
    """
    Shape of a quantization group, as a (row, col) pair.

    Common shapes are exposed as class-level aliases:
    PER_TENSOR is (-1, -1) and PER_TOKEN is (1, -1).
    """

    # Declared here for type checkers; the values are attached right after
    # the class body because they are instances of the class itself.
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
class QuantKey(NamedTuple): class QuantKey(NamedTuple):
""" """
Named tuple for identifying the type of quantization. Named tuple for identifying the type of quantization.
dtype: quantized data type dtype: quantized data type
static: static quantization if True, dynamic if False static: static quantization if True, dynamic if False
per_tensor: per-tensor quantization if True, per-token if False group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
""" """
dtype: torch.dtype dtype: torch.dtype
static: bool static: bool
per_tensor: bool = True group_shape: GroupShape
symmetric: bool = True symmetric: bool = True
def __str__(self): def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'}," return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]}," f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'per_tensor' if self.per_tensor else 'per_token'},"
f"{'a' if not self.symmetric else ''}symmetric)") f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True) kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = { QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
} }
...@@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple): ...@@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple):
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(kFp8StaticTensorSym, False): FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8StaticTensorSym, True): FusedRMSQuantKey(kFp8StaticTensorSym, True):
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, False): FusedRMSQuantKey(kFp8DynamicTokenSym, False):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, True): FusedRMSQuantKey(kFp8DynamicTokenSym, True):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
} }
...@@ -177,9 +206,10 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -177,9 +206,10 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
symmetric=True): symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False, fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype, quant=QuantKey(
dtype=quant_dtype,
static=True, static=True,
per_tensor=True, group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, fused_key) super().__init__(epsilon, fused_key)
...@@ -233,9 +263,10 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -233,9 +263,10 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
symmetric=True): symmetric=True):
key = FusedRMSQuantKey(fused_add=True, key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype, quant=QuantKey(
dtype=quant_dtype,
static=True, static=True,
per_tensor=True, group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
...@@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self, def __init__(self,
epsilon: float, epsilon: float,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
per_tensor: bool, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True): symmetric=True):
key = FusedRMSQuantKey(fused_add=False, key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype, quant=QuantKey(dtype=quant_dtype,
static=False, static=False,
per_tensor=per_tensor, group_shape=group_shape,
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
...@@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self, def __init__(self,
epsilon: float, epsilon: float,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
per_tensor: bool = True, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True): symmetric=True):
key = FusedRMSQuantKey(fused_add=True, key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype, quant=QuantKey(dtype=quant_dtype,
static=False, static=False,
per_tensor=per_tensor, group_shape=group_shape,
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
...@@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass): ...@@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass):
self.patterns, self.record_match) self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant # Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE, RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
per_tensor=False).register(
self.patterns, self.record_match) self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
FP8_DTYPE, self.patterns, self.record_match)
per_tensor=False).register(
self.patterns,
self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache # WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon. # and allow multiple values of epsilon.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment