Unverified Commit b5f6c5f8 authored by Yusuf Mohammad's avatar Yusuf Mohammad Committed by GitHub
Browse files

Added general ND x ND matmul and unit test for it (#39909)


Signed-off-by: default avatarYusuf <yusufmohammad@live.com>
parent bfde49e2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant matmul against torch.matmul for various shape combinations.
Tests correctness (matches torch.matmul) and batch invariance (result for one
item doesn't change based on other items in the batch).
"""
import pytest
import torch
from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import matmul_batch_invariant
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
@skip_unsupported
@pytest.mark.parametrize(
"a_shape,b_shape",
[
# 2D x 2D
((32, 64), (64, 16)),
# 2D x 3D
((64, 16), (4, 16, 32)),
# 3D x 2D
((4, 32, 64), (64, 16)),
# 4D x 2D
((1, 4, 32, 64), (64, 16)),
# 3D x 3D
((4, 32, 64), (4, 64, 16)),
# 3D x 4D
((2, 32, 64), (1, 2, 64, 16)),
# 4D x 3D (Gemma4 pattern)
((1, 2, 32, 64), (2, 64, 16)),
# 4D x 4D
((1, 2, 32, 64), (4, 2, 64, 16)),
# 2D x 4D
((32, 64), (1, 2, 64, 16)),
# 2D x 5D
((32, 64), (1, 2, 2, 64, 16)),
# 5D x 2D
((1, 2, 2, 32, 64), (64, 16)),
# 5D x 5D
((1, 2, 4, 32, 64), (1, 2, 4, 64, 16)),
],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_matmul_correctness(a_shape, b_shape, dtype):
"""
Compare matmul_batch_invariant against torch.matmul for various shapes.
"""
device = torch.device(DEVICE_TYPE)
torch.manual_seed(42)
a = torch.rand(a_shape, dtype=dtype, device=device)
b = torch.rand(b_shape, dtype=dtype, device=device)
# Standard implementation (CUDA ops)
standard_output = torch.matmul(a, b)
# Batch-invariant implementation (Triton)
triton_output = matmul_batch_invariant(a, b)
# Compare outputs
# Use looser tolerance for bfloat16 due to its lower precision
if dtype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16
else:
rtol, atol = 1e-2, 1e-2 # 1% for float16/float32
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"matmul mismatch for a ndim={a.ndim}, b ndim={b.ndim},",
)
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_matmul_batch_invariance(dtype):
"""
Verify that the result for one item is bitwise identical regardless
of what other items are in the batch.
"""
device = torch.device(DEVICE_TYPE)
torch.manual_seed(42)
a_single = torch.rand((1, 64, 32), dtype=dtype, device=device)
b = torch.rand((32, 128), dtype=dtype, device=device)
standard_output = matmul_batch_invariant(a_single, b)
a_batch = torch.rand((8, 64, 32), dtype=dtype, device=device)
a_batch[3] = a_single[0]
batch_output = matmul_batch_invariant(a_batch, b)
batch_output_a = batch_output[3]
assert torch.equal(standard_output[0], batch_output_a)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os
from collections.abc import Callable
from typing import Any
......@@ -611,51 +612,43 @@ def matmul_batch_invariant(a, b, *, out=None):
out.copy_(result)
return out
return result
elif a.ndim == 3 and b.ndim == 3:
# Handle batched case like bmm
return bmm_batch_invariant(a, b, out=out)
elif a.ndim == 3 and b.ndim == 2:
# Handle 3D x 2D: common for linear layers
# (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
# Reshape to 2D, do mm, reshape back
batch, seq, hidden = a.shape
elif b.ndim == 2:
# Handle ND x 2D: Common for linear layers
# (..., batch, seq, hidden) @ (hidden, out) -> (..., batch, seq, out)
batch_dims = a.shape[:-1]
hidden = a.shape[-1]
out_dim = b.shape[-1]
a_2d = a.reshape(-1, hidden)
result_2d = matmul_persistent(a_2d, b)
result = result_2d.reshape(batch, seq, -1)
result = result_2d.reshape(batch_dims + (out_dim,))
if out is not None:
out.copy_(result)
return out
return result
elif a.ndim == 2 and b.ndim == 3:
# Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
# By broadcasting `a` to 3D, we can reuse the batched matrix
# multiplication logic.
a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
return bmm_batch_invariant(a_expanded, b, out=out)
elif a.ndim == 4 and b.ndim == 4:
# Handle 4D attention tensors: [batch, heads, seq, dim]
# Reshape to 3D, process, reshape back
batch, heads, seq_a, dim_a = a.shape
_, _, dim_b, seq_b = b.shape
# Reshape to [batch*heads, seq_a, dim_a]
a_3d = a.reshape(batch * heads, seq_a, dim_a)
b_3d = b.reshape(batch * heads, dim_b, seq_b)
elif a.ndim >= 2 and b.ndim >= 3:
# Generic handler for 2D x ND and ND x ND (except 1D)
# Broadcast dims to ensure both matrices have the same shape
# If 2D x ND, then unsqueeze to add a dim to a
if a.ndim == 2:
a = a.unsqueeze(0)
broadcast_shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
a = a.expand(broadcast_shape + a.shape[-2:])
b = b.expand(broadcast_shape + b.shape[-2:])
batch_dim = math.prod(broadcast_shape)
# Reuse broadcast shape to get all dims except mm dims
a_3d = a.reshape(batch_dim, a.shape[-2], a.shape[-1])
b_3d = b.reshape(batch_dim, b.shape[-2], b.shape[-1])
# Do batched matmul
result_3d = bmm_batch_invariant(a_3d, b_3d)
# Reshape back to [batch, heads, seq_a, seq_b]
result = result_3d.reshape(batch, heads, seq_a, seq_b)
# Reshape back to [broadcast_shape, seq_a, seq_b]
result = result_3d.reshape(broadcast_shape + (a.shape[-2], b.shape[-1]))
if out is not None:
out.copy_(result)
return out
return result
else:
raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
f"3D x 2D, 2D x 3D, and 4D x 4D, "
f"matmul_batch_invariant requires both inputs be at least 2D "
f"got shapes {a.shape} and {b.shape}"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment