"vscode:/vscode.git/clone" did not exist on "bfde49e287cb5522fb0625c8e2b4e03cac20cbb2"
Unverified Commit b5f6c5f8 authored by Yusuf Mohammad's avatar Yusuf Mohammad Committed by GitHub
Browse files

Added general ND x ND matmul and unit test for it (#39909)


Signed-off-by: Yusuf <yusufmohammad@live.com>
parent bfde49e2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant matmul against torch.matmul for various shape combinations.
Tests correctness (matches torch.matmul) and batch invariance (result for one
item doesn't change based on other items in the batch).
"""
import pytest
import torch
from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import matmul_batch_invariant
from vllm.platforms import current_platform
DEVICE_TYPE = current_platform.device_type
@skip_unsupported
@pytest.mark.parametrize(
    "a_shape,b_shape",
    [
        # 2D x 2D
        ((32, 64), (64, 16)),
        # 2D x 3D
        ((64, 16), (4, 16, 32)),
        # 3D x 2D
        ((4, 32, 64), (64, 16)),
        # 4D x 2D
        ((1, 4, 32, 64), (64, 16)),
        # 3D x 3D
        ((4, 32, 64), (4, 64, 16)),
        # 3D x 4D
        ((2, 32, 64), (1, 2, 64, 16)),
        # 4D x 3D (Gemma4 pattern)
        ((1, 2, 32, 64), (2, 64, 16)),
        # 4D x 4D (leading batch dim of `a` broadcasts against `b`)
        ((1, 2, 32, 64), (4, 2, 64, 16)),
        # 2D x 4D
        ((32, 64), (1, 2, 64, 16)),
        # 2D x 5D
        ((32, 64), (1, 2, 2, 64, 16)),
        # 5D x 2D
        ((1, 2, 2, 32, 64), (64, 16)),
        # 5D x 5D
        ((1, 2, 4, 32, 64), (1, 2, 4, 64, 16)),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_matmul_correctness(a_shape, b_shape, dtype):
    """
    Compare matmul_batch_invariant against torch.matmul for various shapes.

    Args:
        a_shape: Shape of the left operand.
        b_shape: Shape of the right operand.
        dtype: Element type of both operands (float16 or bfloat16).

    The test fails if the two results differ by more than the
    dtype-dependent tolerance chosen below.
    """
    device = torch.device(DEVICE_TYPE)
    torch.manual_seed(42)

    a = torch.rand(a_shape, dtype=dtype, device=device)
    b = torch.rand(b_shape, dtype=dtype, device=device)

    # Reference result from the stock PyTorch operator.
    standard_output = torch.matmul(a, b)

    # Result from the implementation under test.
    batch_invariant_output = matmul_batch_invariant(a, b)

    # Use a looser tolerance for bfloat16 due to its lower precision.
    if dtype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1  # 10% relative tolerance for bfloat16
    else:
        rtol, atol = 1e-2, 1e-2  # 1% relative tolerance for float16

    def _with_context(details: str) -> str:
        # A plain-string `msg` would replace the generated mismatch report.
        # A callable receives that report, so the shape/dtype context is
        # prepended and the numeric diagnostics are kept.
        return (
            f"matmul mismatch for a_shape={tuple(a.shape)} (ndim={a.ndim}), "
            f"b_shape={tuple(b.shape)} (ndim={b.ndim}), dtype={dtype}\n"
            f"{details}"
        )

    torch.testing.assert_close(
        batch_invariant_output,
        standard_output,
        rtol=rtol,
        atol=atol,
        msg=_with_context,
    )
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_matmul_batch_invariance(dtype):
    """
    Check that one item's product does not depend on its batch neighbours.

    The same (64, 32) matrix is multiplied once as a batch of one and once
    as entry 3 of a batch of eight otherwise random matrices. The two
    results must be equal element for element (torch.equal), not merely
    close.
    """
    device = torch.device(DEVICE_TYPE)
    torch.manual_seed(42)

    # The item of interest, computed alone in a batch of size one.
    lone_item = torch.rand((1, 64, 32), dtype=dtype, device=device)
    weight = torch.rand((32, 128), dtype=dtype, device=device)
    lone_result = matmul_batch_invariant(lone_item, weight)

    # The same item placed at index 3 of a larger random batch.
    crowded = torch.rand((8, 64, 32), dtype=dtype, device=device)
    crowded[3] = lone_item[0]
    crowded_result = matmul_batch_invariant(crowded, weight)

    item_in_crowd = crowded_result[3]
    assert torch.equal(lone_result[0], item_in_crowd)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
...@@ -611,51 +612,43 @@ def matmul_batch_invariant(a, b, *, out=None): ...@@ -611,51 +612,43 @@ def matmul_batch_invariant(a, b, *, out=None):
out.copy_(result) out.copy_(result)
return out return out
return result return result
elif a.ndim == 3 and b.ndim == 3: elif b.ndim == 2:
# Handle batched case like bmm # Handle ND x 2D: Common for linear layers
return bmm_batch_invariant(a, b, out=out) # (..., batch, seq, hidden) @ (hidden, out) -> (..., batch, seq, out)
elif a.ndim == 3 and b.ndim == 2: batch_dims = a.shape[:-1]
# Handle 3D x 2D: common for linear layers hidden = a.shape[-1]
# (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out) out_dim = b.shape[-1]
# Reshape to 2D, do mm, reshape back
batch, seq, hidden = a.shape
a_2d = a.reshape(-1, hidden) a_2d = a.reshape(-1, hidden)
result_2d = matmul_persistent(a_2d, b) result_2d = matmul_persistent(a_2d, b)
result = result_2d.reshape(batch, seq, -1) result = result_2d.reshape(batch_dims + (out_dim,))
if out is not None: if out is not None:
out.copy_(result) out.copy_(result)
return out return out
return result return result
elif a.ndim == 2 and b.ndim == 3: elif a.ndim >= 2 and b.ndim >= 3:
# Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N) # Generic handler for 2D x ND and ND x ND (except 1D)
# By broadcasting `a` to 3D, we can reuse the batched matrix # Broadcast dims to ensure both matrices have the same shape
# multiplication logic. # If 2D x ND, then unsqueeze to add a dim to a
a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1) if a.ndim == 2:
return bmm_batch_invariant(a_expanded, b, out=out) a = a.unsqueeze(0)
elif a.ndim == 4 and b.ndim == 4: broadcast_shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
# Handle 4D attention tensors: [batch, heads, seq, dim] a = a.expand(broadcast_shape + a.shape[-2:])
# Reshape to 3D, process, reshape back b = b.expand(broadcast_shape + b.shape[-2:])
batch, heads, seq_a, dim_a = a.shape batch_dim = math.prod(broadcast_shape)
_, _, dim_b, seq_b = b.shape # Reuse broadcast shape to get all dims except mm dims
a_3d = a.reshape(batch_dim, a.shape[-2], a.shape[-1])
# Reshape to [batch*heads, seq_a, dim_a] b_3d = b.reshape(batch_dim, b.shape[-2], b.shape[-1])
a_3d = a.reshape(batch * heads, seq_a, dim_a)
b_3d = b.reshape(batch * heads, dim_b, seq_b)
# Do batched matmul # Do batched matmul
result_3d = bmm_batch_invariant(a_3d, b_3d) result_3d = bmm_batch_invariant(a_3d, b_3d)
# Reshape back to [broadcast_shape, seq_a, seq_b]
# Reshape back to [batch, heads, seq_a, seq_b] result = result_3d.reshape(broadcast_shape + (a.shape[-2], b.shape[-1]))
result = result_3d.reshape(batch, heads, seq_a, seq_b)
if out is not None: if out is not None:
out.copy_(result) out.copy_(result)
return out return out
return result return result
else: else:
raise ValueError( raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, " f"matmul_batch_invariant requires both inputs be at least 2D "
f"3D x 2D, 2D x 3D, and 4D x 4D, "
f"got shapes {a.shape} and {b.shape}" f"got shapes {a.shape} and {b.shape}"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment