Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b5f6c5f8
Unverified
Commit
b5f6c5f8
authored
Apr 18, 2026
by
Yusuf Mohammad
Committed by
GitHub
Apr 18, 2026
Browse files
Added general ND x ND matmul and unit test for it (#39909)
Signed-off-by:
Yusuf
<
yusufmohammad@live.com
>
parent
bfde49e2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
31 deletions
+129
-31
tests/v1/determinism/test_matmul_batch_invariant.py
tests/v1/determinism/test_matmul_batch_invariant.py
+105
-0
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+24
-31
No files found.
tests/v1/determinism/test_matmul_batch_invariant.py
0 → 100644
View file @
b5f6c5f8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant matmul against torch.matmul for various shape combinations.
Tests correctness (matches torch.matmul) and batch invariance (result for one
item doesn't change based on other items in the batch).
"""
import
pytest
import
torch
from
utils
import
skip_unsupported
from
vllm.model_executor.layers.batch_invariant
import
matmul_batch_invariant
from
vllm.platforms
import
current_platform
# Device type reported by the active vLLM platform; the tests in this module
# pass it to torch.device() when allocating their input tensors.
DEVICE_TYPE = current_platform.device_type
@skip_unsupported
@pytest.mark.parametrize(
    "a_shape,b_shape",
    [
        # 2D x 2D
        ((32, 64), (64, 16)),
        # 2D x 3D
        ((64, 16), (4, 16, 32)),
        # 3D x 2D
        ((4, 32, 64), (64, 16)),
        # 4D x 2D
        ((1, 4, 32, 64), (64, 16)),
        # 3D x 3D
        ((4, 32, 64), (4, 64, 16)),
        # 3D x 4D
        ((2, 32, 64), (1, 2, 64, 16)),
        # 4D x 3D (Gemma4 pattern)
        ((1, 2, 32, 64), (2, 64, 16)),
        # 4D x 4D
        ((1, 2, 32, 64), (4, 2, 64, 16)),
        # 2D x 4D
        ((32, 64), (1, 2, 64, 16)),
        # 2D x 5D
        ((32, 64), (1, 2, 2, 64, 16)),
        # 5D x 2D
        ((1, 2, 2, 32, 64), (64, 16)),
        # 5D x 5D
        ((1, 2, 4, 32, 64), (1, 2, 4, 64, 16)),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_matmul_correctness(a_shape, b_shape, dtype):
    """
    Compare matmul_batch_invariant against torch.matmul for various shapes.

    Each case draws two seeded random operands of the given shapes and dtype
    on DEVICE_TYPE, multiplies them with both implementations, and requires
    the results to agree within a dtype-dependent tolerance.
    """
    device = torch.device(DEVICE_TYPE)
    torch.manual_seed(42)

    a = torch.rand(a_shape, dtype=dtype, device=device)
    b = torch.rand(b_shape, dtype=dtype, device=device)

    # Reference result from the stock torch.matmul
    standard_output = torch.matmul(a, b)

    # Result from the batch-invariant implementation under test
    triton_output = matmul_batch_invariant(a, b)

    # Use looser tolerance for bfloat16 due to its lower precision
    if dtype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1  # 10% relative tolerance for bfloat16
    else:
        rtol, atol = 1e-2, 1e-2  # 1% for float16 (the only other dtype run)

    def _describe_failure(default_msg: str) -> str:
        # A plain-string msg would replace assert_close's generated report
        # (mismatch count, greatest abs/rel difference). Passing a callable
        # lets us prepend the failing case while keeping that report.
        return (
            f"matmul mismatch for a ndim={a.ndim}, b ndim={b.ndim} "
            f"(a shape={tuple(a.shape)}, b shape={tuple(b.shape)}, "
            f"dtype={dtype})\n{default_msg}"
        )

    torch.testing.assert_close(
        triton_output,
        standard_output,
        rtol=rtol,
        atol=atol,
        msg=_describe_failure,
    )
@skip_unsupported
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_matmul_batch_invariance(dtype):
    """
    Check that one item's product is bit-for-bit the same whether it is
    multiplied on its own or as one entry of a larger batch.
    """
    tensor_kwargs = {"dtype": dtype, "device": torch.device(DEVICE_TYPE)}
    torch.manual_seed(42)

    # Reference: the item multiplied alone, as a batch of one.
    a_single = torch.rand((1, 64, 32), **tensor_kwargs)
    b = torch.rand((32, 128), **tensor_kwargs)
    alone = matmul_batch_invariant(a_single, b)[0]

    # The same item placed at one slot of an otherwise random batch of eight.
    slot = 3
    a_batch = torch.rand((8, 64, 32), **tensor_kwargs)
    a_batch[slot] = a_single[0]
    in_batch = matmul_batch_invariant(a_batch, b)[slot]

    # Exact (bitwise) equality is required, not closeness within a tolerance.
    assert torch.equal(alone, in_batch)
vllm/model_executor/layers/batch_invariant.py
View file @
b5f6c5f8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
os
from
collections.abc
import
Callable
from
typing
import
Any
...
...
@@ -611,51 +612,43 @@ def matmul_batch_invariant(a, b, *, out=None):
out
.
copy_
(
result
)
return
out
return
result
elif
a
.
ndim
==
3
and
b
.
ndim
==
3
:
# Handle batched case like bmm
return
bmm_batch_invariant
(
a
,
b
,
out
=
out
)
elif
a
.
ndim
==
3
and
b
.
ndim
==
2
:
# Handle 3D x 2D: common for linear layers
# (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
# Reshape to 2D, do mm, reshape back
batch
,
seq
,
hidden
=
a
.
shape
elif
b
.
ndim
==
2
:
# Handle ND x 2D: Common for linear layers
# (..., batch, seq, hidden) @ (hidden, out) -> (..., batch, seq, out)
batch_dims
=
a
.
shape
[:
-
1
]
hidden
=
a
.
shape
[
-
1
]
out_dim
=
b
.
shape
[
-
1
]
a_2d
=
a
.
reshape
(
-
1
,
hidden
)
result_2d
=
matmul_persistent
(
a_2d
,
b
)
result
=
result_2d
.
reshape
(
batch
,
seq
,
-
1
)
result
=
result_2d
.
reshape
(
batch
_dims
+
(
out_dim
,)
)
if
out
is
not
None
:
out
.
copy_
(
result
)
return
out
return
result
elif
a
.
ndim
==
2
and
b
.
ndim
==
3
:
# Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
# By broadcasting `a` to 3D, we can reuse the batched matrix
# multiplication logic.
a_expanded
=
a
.
unsqueeze
(
0
).
expand
(
b
.
shape
[
0
],
-
1
,
-
1
)
return
bmm_batch_invariant
(
a_expanded
,
b
,
out
=
out
)
elif
a
.
ndim
==
4
and
b
.
ndim
==
4
:
# Handle 4D attention tensors: [batch, heads, seq, dim]
# Reshape to 3D, process, reshape back
batch
,
heads
,
seq_a
,
dim_a
=
a
.
shape
_
,
_
,
dim_b
,
seq_b
=
b
.
shape
# Reshape to [batch*heads, seq_a, dim_a]
a_3d
=
a
.
reshape
(
batch
*
heads
,
seq_a
,
dim_a
)
b_3d
=
b
.
reshape
(
batch
*
heads
,
dim_b
,
seq_b
)
elif
a
.
ndim
>=
2
and
b
.
ndim
>=
3
:
# Generic handler for 2D x ND and ND x ND (except 1D)
# Broadcast dims to ensure both matrices have the same shape
# If 2D x ND, then unsqueeze to add a dim to a
if
a
.
ndim
==
2
:
a
=
a
.
unsqueeze
(
0
)
broadcast_shape
=
torch
.
broadcast_shapes
(
a
.
shape
[:
-
2
],
b
.
shape
[:
-
2
])
a
=
a
.
expand
(
broadcast_shape
+
a
.
shape
[
-
2
:])
b
=
b
.
expand
(
broadcast_shape
+
b
.
shape
[
-
2
:])
batch_dim
=
math
.
prod
(
broadcast_shape
)
# Reuse broadcast shape to get all dims except mm dims
a_3d
=
a
.
reshape
(
batch_dim
,
a
.
shape
[
-
2
],
a
.
shape
[
-
1
])
b_3d
=
b
.
reshape
(
batch_dim
,
b
.
shape
[
-
2
],
b
.
shape
[
-
1
])
# Do batched matmul
result_3d
=
bmm_batch_invariant
(
a_3d
,
b_3d
)
# Reshape back to [batch, heads, seq_a, seq_b]
result
=
result_3d
.
reshape
(
batch
,
heads
,
seq_a
,
seq_b
)
# Reshape back to [broadcast_shape, seq_a, seq_b]
result
=
result_3d
.
reshape
(
broadcast_shape
+
(
a
.
shape
[
-
2
],
b
.
shape
[
-
1
]))
if
out
is
not
None
:
out
.
copy_
(
result
)
return
out
return
result
else
:
raise
ValueError
(
f
"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
f
"3D x 2D, 2D x 3D, and 4D x 4D, "
f
"matmul_batch_invariant requires both inputs be at least 2D "
f
"got shapes
{
a
.
shape
}
and
{
b
.
shape
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment