Unverified Commit bd56c983 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename...


[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)
Signed-off-by: default avatarluka <luka@neuralmagic.com>
parent 084bbac8
......@@ -13,21 +13,26 @@ class TestBackend:
This class provides a simple Inductor backend that can be used for testing.
It takes a list of custom passes and runs them after Inductor's passes.
It also saves the graph before and after the custom passes for inspection.
Inductor config can be modified directly by editing the inductor_config
property. This can be helpful for adding passes like the
'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
"""
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
self.custom_passes = list(passes)
from torch._inductor import config
self.current_config = config.shallow_copy_dict()
self.current_config['force_disable_caches'] = True
self.current_config['post_grad_custom_post_pass'] = self.post_pass
self.inductor_config = config.shallow_copy_dict()
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.current_config)
config_patches=self.inductor_config)
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
......
......@@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig
from .backend import TestBackend
......@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
......
......@@ -5,23 +5,25 @@ import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, static: bool, *args,
**kwargs):
def __init__(self, hidden_size: int, eps: float, static: bool,
cutlass_fp8_enabled: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
if static:
......@@ -41,7 +43,8 @@ class TestModel(torch.nn.Module):
self.w[0],
self.wscale[0],
self.scale[0],
use_per_token_if_dynamic=True)
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
......@@ -49,7 +52,8 @@ class TestModel(torch.nn.Module):
self.w[1],
self.wscale[1],
self.scale[1],
use_per_token_if_dynamic=True)
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
......@@ -59,21 +63,28 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps, static)
backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
......@@ -107,12 +118,12 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]
# In pre-nodes, fp8 quant should be present and fused kernels should not
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In post-nodes, fused kernels should be present and fp8 quant should not
# In post-nodes, fused kernels should be there and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
# SPDX-License-Identifier: Apache-2.0
from typing import Union
from typing import Iterable, Union
import torch.fx
from torch import SymInt
......@@ -13,15 +13,15 @@ from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class RedundantReshapesPass(VllmInductorPass):
class NoOpEliminationPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape operations.
This is an inductor pass that removes redundant reshape/slice operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D-case.
Example graph:
in the 2D-case. Additionally, torch internal no-op elimination pass does
not handle certain slice variants.
Example graph 1:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
......@@ -31,11 +31,27 @@ class RedundantReshapesPass(VllmInductorPass):
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Example graph 2:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
Can be replaced with:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
TODO(luka): This is currently tested in test_fusion,
but separate tests could be good.
"""
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_reshapes")
self.dump_graph(graph, "before_noop_elimination")
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
......@@ -50,23 +66,52 @@ class RedundantReshapesPass(VllmInductorPass):
# Invalid reshape args, skip
continue
if all(
self.dims_equivalent(s, i_s)
for s, i_s in zip(shape, input_shape)):
if self.all_dims_equivalent(shape, input_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice.Tensor):
input, dim_index, start, end = node.args[:4]
input_shape = input.meta["val"].shape
i_dim = input_shape[dim_index]
if start == 0 and self.dims_equivalent(end, i_dim):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes", count)
elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
view_dim = view_shape[dim_index]
self.dump_graph(graph, "after_reshapes")
# Check that view fully covers base and the full view is used
# (if the view fully covered the base after slicing but was not
# fully used, we could replace slice_scatter with a simple slice
# but that's a niche case).
if (base_shape == view_shape and start == 0
and self.dims_equivalent(end, view_dim)):
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
self.dump_graph(graph, "after_noop_elimination")
self.end_and_log()
def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
i_dims: Iterable[Union[int, SymInt]]):
return all(
self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
......
......@@ -11,7 +11,7 @@ from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .reshapes import RedundantReshapesPass
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
......@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):
The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (RedundantReshapesPass, FusionPass)
2. default passes (NoopEliminationPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
......@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_reshape:
self.passes += [RedundantReshapesPass(pass_config)]
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
......
......@@ -28,8 +28,8 @@ class VllmInductorPass(InductorPass):
self.config = config
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str):
if stage in self.config.dump_graph_stages:
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
if stage in self.config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
......@@ -49,3 +49,17 @@ class VllmInductorPass(InductorPass):
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class PrinterInductorPass(VllmInductorPass):
def __init__(self,
name: str,
config: CompilationConfig.PassConfig,
always=False):
super().__init__(config)
self.name = name
self.always = always
def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, self.name, always=self.always)
......@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graphs. Default is .
- enable_fusion: whether to enable the custom fusion pass.
- enable_reshape: whether to enable the custom reshape elimination pass.
TODO better pass enabling system.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
"""
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_reshape: bool = True
enable_noop: bool = True
def uuid(self):
"""
......@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(
include={"enable_fusion", "enable_reshape"})
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).digest()
def model_post_init(self, __context: Any) -> None:
if not self.enable_reshape and self.enable_fusion:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
......@@ -3411,7 +3410,7 @@ class VllmConfig:
self.compilation_config.use_inductor = True
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.pass_config.enable_noop = False
self.compilation_config.level = CompilationLevel.PIECEWISE
self._set_cudagraph_sizes()
......
......@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
# Input scaling factors are no longer optional in _scaled_mm starting
......@@ -161,10 +162,14 @@ def apply_fp8_linear(
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
config = get_current_vllm_config().compilation_config
do_pad = config.level < CompilationLevel.PIECEWISE
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=17,
num_token_padding=17 if do_pad else None,
use_per_token_if_dynamic=use_per_token_if_dynamic)
per_tensor_weights = (weight_scale.numel() == 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment