Unverified Commit 67661375 authored by Andy Lo's avatar Andy Lo Committed by GitHub
Browse files

[BugFix] Fix noop elimination edge case (#26394)


Signed-off-by: default avatarAndy Lo <andy@mistral.ai>
parent 213b6445
...@@ -12,15 +12,23 @@ from .backend import TestBackend ...@@ -12,15 +12,23 @@ from .backend import TestBackend
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024]) # Important edge case is when `num_tokens == buffer_size`
@pytest.mark.parametrize(
("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)]
)
@pytest.mark.parametrize("hidden_size", [64, 4096]) @pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size): def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype)
def forward(self, x): def forward(self, x):
x += self.pos_embed[: x.shape[0]]
# Chain of reshapes # Chain of reshapes
y = x.reshape(-1, 128, 32) y = x.reshape(-1, 128, 32)
z = y.reshape(-1, 4096) z = y.reshape(-1, 4096)
...@@ -65,9 +73,10 @@ def test_noop_elimination(dtype, num_tokens, hidden_size): ...@@ -65,9 +73,10 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# The no-op reshape and slice should be eliminated. # The no-op reshape and slice should be eliminated.
# The initial slice on the positional embedding should remain.
# The chain of reshapes should be fused into a single reshape. # The chain of reshapes should be fused into a single reshape.
assert backend.op_count(torch.ops.aten.reshape.default) == 1 assert backend.op_count(torch.ops.aten.reshape.default) == 1
assert backend.op_count(torch.ops.aten.slice.Tensor) == 0 assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0 assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
......
...@@ -81,49 +81,32 @@ class NoOpEliminationPass(VllmInductorPass): ...@@ -81,49 +81,32 @@ class NoOpEliminationPass(VllmInductorPass):
graph.erase_node(input) graph.erase_node(input)
count += 1 count += 1
# Case 2: remove this reshape if it produces the original shape # remove reshape/slice if it produces the original shape
input, shape = node.args[:2] if is_func(node, torch.ops.aten.reshape.default) or is_func(
input_shape = input.meta["val"].shape node, torch.ops.aten.slice.Tensor
if len(shape) != len(input_shape): ):
# Reshape changing rank, skip input = node.args[0]
continue
if shape.count(-1) > 1:
# Invalid reshape args, skip
continue
if self.reshape_all_dims_equivalent(shape, input_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice.Tensor):
# python slicing semantics are different from reshape
# Don't treat -1 as inferred dimension
input, dim_index, start, end = node.args[:4]
input_shape = input.meta["val"].shape input_shape = input.meta["val"].shape
output_shape = node.meta["val"].shape output_shape = node.meta["val"].shape
if self.all_dims_equivalent(input_shape, output_shape):
if output_shape == input_shape:
node.replace_all_uses_with(input) node.replace_all_uses_with(input)
graph.erase_node(node) graph.erase_node(node)
count += 1 count += 1
elif is_func(node, torch.ops.aten.slice_scatter.default): elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5] base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape view_shape = view.meta["val"].shape
if base_shape == view_shape: if self.all_dims_equivalent(base_shape, view_shape):
node.replace_all_uses_with(view) node.replace_all_uses_with(view)
graph.erase_node(node) graph.erase_node(node)
count += 1 count += 1
logger.debug("Removed %s no-op reshapes and slices", count) logger.debug("Removed %s no-op reshapes and slices", count)
# ---------------------- Reshape helpers ---------------------- # ---------------------- Shape comparison helpers ----------------------
def reshape_dims_equivalent( def dims_equivalent(
self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt] self, dim: Union[int, SymInt], i_dim: Union[int, SymInt]
) -> bool: ) -> bool:
""" """
This function checks if two dimensions are equivalent. This function checks if two dimensions are equivalent.
...@@ -131,27 +114,24 @@ class NoOpEliminationPass(VllmInductorPass): ...@@ -131,27 +114,24 @@ class NoOpEliminationPass(VllmInductorPass):
:param i_dim: The corresponding dimension in the input tensor :param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent? :return: Are the dimensions equivalent?
There are three cases in which the dimensions are equivalent: There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers) 1. The dimensions are equal (both integers)
2. The reshape dimension is -1 (i.e. inferred) 2. The dimensions both correspond to the same SymInt
3. The dimensions both correspond to the same SymInt
While case 2 does not guarantee the dimensions are equal,
they are equal if all other dimensions are equal.
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
""" """
# Case 1 and 2 # Case 1
if dim == i_dim or dim == -1: if isinstance(i_dim, int) and isinstance(dim, int):
return True return dim == i_dim
# Case 3 # Case 2
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
return dim == i_dim
def reshape_all_dims_equivalent( return False
self,
dims: Iterable[Union[int, torch.fx.Node]], def all_dims_equivalent(
i_dims: Iterable[Union[int, SymInt]], self, dims: Iterable[Union[int, SymInt]], i_dims: Iterable[Union[int, SymInt]]
) -> bool: ) -> bool:
return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
# Different ranks can't be equivalent
return False
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment