Unverified Commit f98548b9 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
parent 96846bb3
......@@ -305,6 +305,7 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
from torch import fx
from torch._ops import OpOverload
from vllm.compilation.fx_utils import (find_specified_fn,
find_specified_fn_maybe)
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
......@@ -48,18 +49,19 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
def check_before_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe, \
ops_fully_replaced=True):
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
for op in ops:
find_fn(self.graph_pre_pass.nodes, op)
if ops_fully_replaced:
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe):
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops:
find_fn(self.graph_post_pass.nodes, op)
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
\ No newline at end of file
......@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(),
ops_fully_replaced=False)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
# In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist
......
......@@ -7,8 +7,7 @@ import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
FusionPass, GroupShape, QuantKey)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
......@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
group_shape=group_shape,
symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
......@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
find_auto_fn_maybe)
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
find_auto_fn_maybe)
backend.check_after_ops(model.ops_in_model_after())
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch._dynamo
from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform
# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
    """
    Compare generation with and without the attention+quant fusion pass.

    The model is run twice through custom Dynamo backends: once with only
    NoOpEliminationPass (unfused) and once with AttnFusionPass added (fused).
    The test then checks that:
    - the quant op count drops when at least one layer supports fusion,
    - each attention node gains an ``output_scale`` exactly when its layer
      reports support via ``fused_output_quant_supported``,
    - the fused and unfused generations are equal.

    :param example_prompts: prompt fixture; prompt 4 is excluded
    :param monkeypatch: pytest fixture used to set environment variables
    :param model: model name passed to ``LLM``
    :param quant_key: quantization scheme whose op is expected to be fused
    :param use_triton_fa: value for ``VLLM_USE_TRITON_FLASH_ATTN``
    """
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # Use global backends: the compilation config refers to them by import
    # string, so they have to live at module level.
    global backend, backend_unfused

    use_v1 = False  # can be made a param once V1 support added
    monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.9,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.9,
               max_model_len=2048)

    # check support (one entry per registered layer, in registration order)
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                quant_key.static,
                                                quant_key.group_shape)
        for layer in compile_config.static_forward_context.values()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i, (node_pre, node_post) in enumerate(
            zip(attn_nodes_pre, attn_nodes_post)):
        assert node_pre.kwargs["output_scale"] is None
        fused = node_post.kwargs["output_scale"] is not None
        expected = attn_fusion_supported[i]
        # The message states the expectation (not the observed state), so a
        # failure reads as what should have happened.
        assert fused == expected, \
            f"Node {i} {'' if expected else 'not '}expected " \
            f"to have fused output quant"

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None
......@@ -1225,6 +1225,7 @@ def scaled_fp8_quant(
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
......@@ -1256,7 +1257,12 @@ def scaled_fp8_quant(
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
......
......@@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]):
    """
    Report whether this attention implementation can fuse output
    quantization.

    The AttnFusionPass uses the answer to fuse output quantization only
    onto implementations that support it. This implementation reports no
    support for any quantization scheme.

    TODO(luka) merge parameters into QuantDescriptor

    :param dtype: quantized dtype
    :param static: static or dynamic quantization
    :param group_shape: quant group shape. (-1, -1) for per-tensor.
    :return: is fusion supported for this type of quantization
    """
    return False
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
......@@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......
......@@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......
......@@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
......@@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
(
query,
query_succ,
......
......@@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
......@@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (
......
......@@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
# TODO: directly write to output tensor
num_heads: int = self.num_heads
head_size: int = self.head_size
......
......@@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
......
......@@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
......@@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for IpexAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
......
......@@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
if attn_metadata.is_profile_run and \
attn_metadata.context_chunk_workspace is not None:
# During the profile run try to simulate the worst-case output size
......
......@@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
......@@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
......
......@@ -38,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
......@@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                 group_shape: tuple[int, int]):
    """
    Report whether output quantization can be fused for this scheme.

    :param dtype: quantized dtype
    :param static: static or dynamic quantization
    :param group_shape: quant group shape. (-1, -1) for per-tensor.
    :return: falsy unless the Triton flash attention path is in use and
        the scheme is static per-tensor quantization to the platform's
        fp8 dtype
    """
    # Only supported in the Triton backend
    if not self.use_triton_flash_attn:
        return False

    is_platform_fp8 = dtype == current_platform.fp8_dtype()
    # (-1, -1) is the per-tensor group shape
    return is_platform_fp8 and static and group_shape == (-1, -1)
def forward(
self,
layer: AttentionLayer,
......@@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None and not self.use_triton_flash_attn:
raise NotImplementedError(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
......@@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
output_scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
......@@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
......@@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert _PARTITION_SIZE_ROCM % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
......@@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
output_scale,
)
else:
output[num_prefill_tokens:] = paged_attn.forward_decode(
# PagedAttention does not support fused quant, manually quantize
if output_scale is None:
out_pa = output[num_prefill_tokens:]
else:
out_pa = torch.empty_like(output[num_prefill_tokens:],
dtype=query.dtype)
out_pa[:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
......@@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
# Manually perform quantization
if output_scale is not None:
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
out_q = output.view(-1, self.num_heads * self.head_size)
ops.scaled_fp8_quant(out_uq,
output_scale,
output=out_q[num_prefill_tokens:])
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
......
......@@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
......@@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")
# For warming-up
if attn_metadata is None:
......
......@@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
......
......@@ -430,6 +430,7 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
......@@ -444,7 +445,8 @@ def unified_attention_with_output(
value,
kv_cache,
attn_metadata,
output=output)
output=output,
output_scale=output_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......@@ -455,6 +457,7 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, NamedTuple, Optional
from typing import Callable, ClassVar, NamedTuple, Optional
import torch
import torch._inductor.pattern_matcher as pm
......@@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# A class deriving directly from NamedTuple cannot have static members, so
# the fields are declared on a private proxy and the public class extends it.
class _GroupShape(NamedTuple):
    row: int
    col: int


class GroupShape(_GroupShape):
    """
    Describes the shape of a quantization group as a (row, col) pair.

    Common shapes (per-tensor, per-token) are exposed as static members.
    """

    # Aliases for common quantization group shapes; the values are assigned
    # right after the class body, once GroupShape can be instantiated.
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


GroupShape.PER_TENSOR = GroupShape(row=-1, col=-1)
GroupShape.PER_TOKEN = GroupShape(row=1, col=-1)
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
per_tensor: per-tensor quantization if True, per-token if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype: torch.dtype
static: bool
per_tensor: bool = True
group_shape: GroupShape
symmetric: bool = True
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'per_tensor' if self.per_tensor else 'per_token'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa
kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
......@@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple):
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8StaticTensorSym, True):
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
}
......@@ -177,10 +206,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
......@@ -233,10 +263,11 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
......@@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
group_shape=group_shape,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool = True,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
group_shape=group_shape,
symmetric=symmetric))
super().__init__(epsilon, key)
......@@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass):
self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
per_tensor=False).register(
self.patterns, self.record_match)
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE,
per_tensor=False).register(
self.patterns,
self.record_match)
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment