"csrc/vscode:/vscode.git/clone" did not exist on "4f35be10a96feeca0328d3ab8d359e1eaae5c23d"
Unverified Commit bd7157a0 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)


Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent be429d0c
...@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized ...@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload from torch._ops import OpOverload
from vllm.config import VllmConfig from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -92,13 +93,19 @@ class RMSNormQuantPattern: ...@@ -92,13 +93,19 @@ class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey): def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon self.epsilon = epsilon
self.quant_dtype = key.quant.dtype self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" self.model_dtype = config.model_config.dtype if config.model_config else None
self.QUANT_OP = QUANT_OPS[key.quant]
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key] self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
class RMSNormStaticQuantPattern(RMSNormQuantPattern): class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
...@@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing # Cannot use methods, as the self argument affects tracing
def pattern( def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
result: torch.Tensor, result_rms = self.rmsnorm_matcher(input, weight)
result_rms: torch.Tensor, return self.quant_matcher(result_rms, scale)[0]
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
# result def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
return at2[1] # In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
def replacement( result = torch.empty(
result: torch.Tensor, input.shape, device=input.device, dtype=self.quant_dtype
result_rms: torch.Tensor, )
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1] return at[1]
inputs = [ inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result # input, weight
empty_bf16(5, 4), # result_rms *self.rmsnorm_matcher.inputs(),
empty_bf16(5, 4), # input self.quant_matcher.inputs()[1], # scale
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
] ]
pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
...@@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(
result: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
at = auto_functionalized( result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
RMS_ADD_OP, result, _ = self.quant_matcher(result_rms, scale)
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale
)
# result, residual return result, residual
return at1[1], at[2]
def replacement( def replacement(
result: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1], at[2] return at[1], at[2]
inputs = [ inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result # input, weight, residual
empty_bf16(5, 4), # input *self.rmsnorm_matcher.inputs(),
empty_bf16(5, 4), # residual self.quant_matcher.inputs()[1], # scale
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
] ]
pm.register_replacement( pm.register_replacement(
...@@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(input: torch.Tensor, weight: torch.Tensor):
result: torch.Tensor, result_rms = self.rmsnorm_matcher(input, weight)
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
)
# result, scale # result, scale
return at2[1], at2[2] return self.quant_matcher(result_rms)
def replacement( def replacement(input: torch.Tensor, weight: torch.Tensor):
result: torch.Tensor, # In case we're matching native rms-norm, conversions might be
result_rms: torch.Tensor, # optimized out. We convert here just to be safe.
input: torch.Tensor, input = input.to(dtype=self.model_dtype)
weight: torch.Tensor,
scale: torch.Tensor, result = torch.empty_like(input, dtype=self.quant_dtype)
): scale = self.quant_matcher.make_scale(input)
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, scale # result, scale
return at[1], at[2] return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
inputs, self.rmsnorm_matcher.inputs(),
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )
...@@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
result: torch.Tensor, result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
input: torch.Tensor, result, scale = self.quant_matcher(result_rms)
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
)
# result, residual, scale return result, residual, scale
return at1[1], at[2], at1[2]
def replacement( def replacement(
result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
): ):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, residual, scale # result, residual, scale
return at[1], at[3], at[2] return at[1], at[3], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement( pm.register_replacement(
pattern, pattern,
replacement, replacement,
inputs, self.rmsnorm_matcher.inputs(),
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
) )
...@@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
pass_name="rmsnorm_quant_fusion_pass" pass_name="rmsnorm_quant_fusion_pass"
) )
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant # Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns self.patterns
) )
# Fuse rms_norm + dynamic per-token fp8 quant # Fuse rms_norm + static fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns self.patterns
) )
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
...@@ -20,7 +22,9 @@ from vllm.platforms import current_platform ...@@ -20,7 +22,9 @@ from vllm.platforms import current_platform
from vllm.utils import round_up from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .fx_utils import is_func
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC): ...@@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC):
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
@staticmethod @staticmethod
def wrap_trace_fn(process_fx, trace_fn): def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs)) gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped return wrapped
...@@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC): ...@@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC):
from torch._inductor.fx_passes.post_grad import view_to_reshape from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm) view_to_reshape(gm)
return gm
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass): def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key): if self.layer.impl.fused_output_quant_supported(self.quant_key):
...@@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
) )
super().__init__(layer, quant_key, dtype) super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(
...@@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
at1 = auto_functionalized( at1 = auto_functionalized(
...@@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
attn_out_view = RESHAPE_OP( attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size] at1[1], [q.shape[0], self.num_heads * self.head_size]
) )
at2 = auto_functionalized(
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale return self.quant_matcher(attn_out_view, scale)[0]
)
return at2[1]
def replacement( def replacement(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
# attn output in quant_dtype # attn output in quant_dtype
...@@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [ inputs = [
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v self.empty(5, self.num_heads, self.head_size), # v
self.empty( self.empty(5, self.num_heads, self.head_size), # attn_output
5, self.num_heads, self.head_size, dtype=self.dtype
), # attn_output
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
empty_fp32(1, 1), # scale empty_fp32(1, 1), # scale
] ]
...@@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
replacement, replacement,
inputs, inputs,
AttentionQuantPattern.wrap_trace_fn( AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
), ),
pm_pass, pm_pass,
) )
...@@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
replacement, replacement,
inputs, inputs,
AttentionQuantPattern.wrap_trace_fn( AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
), ),
pm_pass, pm_pass,
) )
......
...@@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator ...@@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator
from torch import fx from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload from torch._ops import OpOverload, OpOverloadPacket
def is_func(node: fx.Node, target) -> bool: def is_func(node: fx.Node, target) -> bool:
...@@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node: ...@@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
# An auto-functionalization-aware utility for finding nodes with a specific op # An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: # Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable: if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op) yield from graph.find_nodes(op="call_function", target=op)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.enabled = enabled
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args, **kws):
pass
@abstractmethod
def forward_native(self, *args, **kws):
pass
def __call__(self, *args, **kws):
return self.forward(*args, **kws)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
raise NotImplementedError
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
_, result, residual = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
return result, residual
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale
)
return result, scale
else:
assert scale is None
scale = self.make_scale(input)
_, result, scale = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
)
return result, scale
def forward_native(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
scale_shape = (
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
if self.quant_key.scale.static:
return [input, self.empty_f32(1, 1)]
return [input]
...@@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): ...@@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
import depyf import depyf
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
logger.debug("Dumping depyf output to %s", path)
global context_manager global context_manager
context_manager = depyf.prepare_debug(path.as_posix()) context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__() context_manager.__enter__()
......
...@@ -5,7 +5,7 @@ import functools ...@@ -5,7 +5,7 @@ import functools
from torch import fx as fx from torch import fx as fx
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import set_env_var from vllm.utils import set_env_var
...@@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass): ...@@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass):
def configure(self, config: VllmConfig): def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sequence_parallelism: # Set the current vllm config to allow tracing CustomOp instances
self.passes += [SequenceParallelismPass(config)] with set_current_vllm_config(config, check_compile=False):
if self.pass_config.enable_async_tp: if self.pass_config.enable_noop:
self.passes += [AsyncTPPass(config)] self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_fi_allreduce_fusion: if self.pass_config.enable_sequence_parallelism:
self.passes += [AllReduceFusionPass(config)] self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fusion: if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [RMSNormQuantFusionPass(config)] self.passes += [AllReduceFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion: if self.pass_config.enable_fusion:
self.passes += [AttnFusionPass(config)] self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
# needs a functional graph if self.pass_config.enable_attn_fusion:
self.post_cleanup = PostCleanupPass(config) self.passes += [AttnFusionPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config)
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache] # [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph # In PyTorch 2.9, torch.compile has a bug where the graph
......
...@@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass): ...@@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass):
f" please add to dump_patterns if there are any errors.\n\n" f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import " f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n" f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *", f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f, file=f,
) )
......
...@@ -178,14 +178,11 @@ class RMSNorm(CustomOp): ...@@ -178,14 +178,11 @@ class RMSNorm(CustomOp):
self.variance_size_override = ( self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size None if var_hidden_size == hidden_size else var_hidden_size
) )
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight self.has_weight = has_weight
if dtype is not None: self.weight = torch.ones(hidden_size, dtype=weight_dtype)
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
if self.has_weight: if self.has_weight:
self.weight = nn.Parameter(self.weight) self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm(): if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func( self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
...@@ -195,47 +192,69 @@ class RMSNorm(CustomOp): ...@@ -195,47 +192,69 @@ class RMSNorm(CustomOp):
with_fused_add=True, dtype=weight_dtype with_fused_add=True, dtype=weight_dtype
) )
def forward_native( @staticmethod
self, def forward_static(
x: torch.Tensor, x: torch.Tensor,
variance_epsilon: float,
hidden_size: int,
orig_dtype: torch.dtype,
weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
x = x + residual.to(torch.float32) # residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype) residual = x.to(orig_dtype)
hidden_size = x.shape[-1] if x.shape[-1] != hidden_size:
if hidden_size != self.hidden_size:
raise ValueError( raise ValueError(
"Expected hidden_size to be " f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
f"{self.hidden_size}, but found: {hidden_size}"
) )
if self.variance_size_override is None: if variance_size_override is None:
x_var = x x_var = x
else: else:
if hidden_size < self.variance_size_override: if hidden_size < variance_size_override:
raise ValueError( raise ValueError(
"Expected hidden_size to be at least " "Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}" f"{variance_size_override}, but found: {hidden_size}"
) )
x_var = x[:, :, : self.variance_size_override] x_var = x[:, :, :variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True) variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype) x = x.to(orig_dtype)
if self.has_weight: if weight is not None:
x = x * self.weight x = x * weight
if residual is None: if residual is None:
return x return x
else: else:
return x, residual return x, residual
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(
x,
self.variance_epsilon,
self.hidden_size,
x.dtype,
self.weight.data if self.has_weight else None,
residual,
self.variance_size_override,
)
def forward_cuda( def forward_cuda(
self, self,
x: torch.Tensor, x: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment