Unverified Commit bd7157a0 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)


Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent be429d0c
......@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
......@@ -92,13 +93,19 @@ class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
self.QUANT_OP = QUANT_OPS[key.quant]
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
......@@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
# result
return at2[1]
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
......@@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
# input, weight
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
......@@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale
)
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
# result, residual
return at1[1], at[2]
return result, residual
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
......@@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
# input, weight, residual
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pm.register_replacement(
......@@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
)
def pattern(input: torch.Tensor, weight: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return at2[1], at2[2]
return self.quant_matcher(result_rms)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
def replacement(input: torch.Tensor, weight: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
......@@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, scale
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
......@@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
)
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
# result, residual, scale
return at1[1], at[2], at1[2]
return result, residual, scale
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
......@@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, residual, scale
return at[1], at[3], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
......@@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
pass_name="rmsnorm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
......
......@@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
......@@ -20,7 +22,9 @@ from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .fx_utils import is_func
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
......@@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC):
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
......@@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
......@@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(
......@@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
......@@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
)
return at2[1]
return self.quant_matcher(attn_out_view, scale)[0]
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
# attn output in quant_dtype
......@@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v
self.empty(
5, self.num_heads, self.head_size, dtype=self.dtype
), # attn_output
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
]
......@@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
......@@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
......
......@@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket
def is_func(node: fx.Node, target) -> bool:
......@@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.enabled = enabled
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args, **kws):
pass
@abstractmethod
def forward_native(self, *args, **kws):
pass
def __call__(self, *args, **kws):
return self.forward(*args, **kws)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
raise NotImplementedError
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
_, result, residual = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
return result, residual
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale
)
return result, scale
else:
assert scale is None
scale = self.make_scale(input)
_, result, scale = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
)
return result, scale
def forward_native(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
scale_shape = (
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
if self.quant_key.scale.static:
return [input, self.empty_f32(1, 1)]
return [input]
......@@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
import depyf
path.mkdir(parents=True, exist_ok=True)
logger.debug("Dumping depyf output to %s", path)
global context_manager
context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__()
......
......@@ -5,7 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var
......@@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass):
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
......
......@@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass):
f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *",
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f,
)
......
......@@ -178,14 +178,11 @@ class RMSNorm(CustomOp):
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
......@@ -195,47 +192,69 @@ class RMSNorm(CustomOp):
with_fused_add=True, dtype=weight_dtype
)
def forward_native(
self,
@staticmethod
def forward_static(
x: torch.Tensor,
variance_epsilon: float,
hidden_size: int,
orig_dtype: torch.dtype,
weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
if x.shape[-1] != hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
)
if self.variance_size_override is None:
if variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
if hidden_size < variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
f"{variance_size_override}, but found: {hidden_size}"
)
x_var = x[:, :, : self.variance_size_override]
x_var = x[:, :, :variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if weight is not None:
x = x * weight
if residual is None:
return x
else:
return x, residual
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(
x,
self.variance_epsilon,
self.hidden_size,
x.dtype,
self.weight.data if self.has_weight else None,
residual,
self.variance_size_override,
)
def forward_cuda(
self,
x: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment