Unverified Commit eae9a9fb authored by Stefan He, committed by GitHub

Fix batch invariant ops (#11368)

parent 2674c1d2
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
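    # Start the epilogue tile counter one grid stride below start_pid, so the
    # first `tile_id_c += NUM_SMS` in the main loop lands back on start_pid.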
    tile_id_c = start_pid - NUM_SMS
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
            )
            accumulator = tl.dot(a, b, accumulator)
        tile_id_c += NUM_SMS
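        # Recompute this tile's (pid_m, pid_n) from tile_id_c for the epilogue store.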
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
        accumulator += bias
        if c_ptr.dtype.element_ty == tl.float8e4nv:
            c = accumulator.to(tl.float8e4nv)
        elif c_ptr.dtype.element_ty == tl.bfloat16:
            c = accumulator.to(tl.bfloat16)
        elif c_ptr.dtype.element_ty == tl.float32:
            c = accumulator.to(tl.float32)
        else:
            c = accumulator.to(tl.float16)
        tl.store(c_ptrs, c, mask=c_mask)
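
The new branches let the epilogue cast the float32 accumulator to whatever element type the output tensor was allocated with, instead of always falling through to the float16 cast. A minimal host-side sketch of the same dispatch (illustrative only, not part of this change; `cast_like_epilogue` is a hypothetical helper, and `torch.float8_e4m3fn` is assumed to be PyTorch's counterpart of `tl.float8e4nv`):

    import torch

    def cast_like_epilogue(accumulator: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
        # The accumulator is float32; mirror the kernel's dtype branches.
        if out_dtype == torch.float8_e4m3fn:
            return accumulator.to(torch.float8_e4m3fn)
        elif out_dtype == torch.bfloat16:
            return accumulator.to(torch.bfloat16)
        elif out_dtype == torch.float32:
            return accumulator  # already float32, no cast needed
        else:
            return accumulator.to(torch.float16)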
......
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/test_batch_invariance.py
import math
import unittest

import torch

from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode
from sglang.test.test_utils import CustomTestCase

device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu")
torch.set_default_device(device_type)

# Just to get the logging out of the way
with set_batch_invariant_mode(True):
    pass


class TestBatchInvariantOps(CustomTestCase):
    def _test_batch_invariance(self, M, K, N, dtype):
        """
        Test that matrix operations produce identical results for:
        - Method 1: Matrix-vector multiplication (batch size 1)
        - Method 2: Matrix-matrix multiplication, then slice (full batch)
        """
        a = torch.linspace(-100, 100, M * K, dtype=dtype).reshape(M, K)

        # Create non-contiguous tensor
        b = torch.linspace(-100, 100, K * N, dtype=dtype).reshape(N, K)
        b = b.transpose(0, 1)

        # Method 1: Matrix-vector multiplication (batch size 1)
        out1 = torch.mm(a[:1], b)

        # Method 2: Matrix-matrix multiplication, then slice (full batch)
        out2_pre = torch.mm(a, b)
        out2 = out2_pre[:1]

        # Check if results are identical
        diff = (out1 - out2).abs().max()
        return diff.item()
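
    # Note: diffs can appear at all because floating-point addition is not
    # associative, so a matmul kernel that picks a different reduction split
    # (e.g. split-K) for batch size 1 than for the full batch can round
    # differently. Batch-invariant mode pins the reduction strategy so both
    # paths produce bit-identical results.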

    def _run_multiple_iterations(self, iters, M, K, N, dtype):
        """Run multiple iterations and collect diff statistics."""
        difflist = []
        for _ in range(iters):
            diff = self._test_batch_invariance(M, K, N, dtype)
            difflist.append(diff)
        return difflist

    def _assert_batch_invariant_results(self, difflist, dtype, test_name):
        """
        Assert that in batch-invariant mode:
        1. No diff is NaN.
        2. Every diff is exactly 0.
        3. The max, min, and range of the diffs are all 0.
        """
        max_diff = max(difflist)
        min_diff = min(difflist)
        diff_range = max_diff - min_diff

        # Check for NaN values
        self.assertFalse(
            math.isnan(max_diff), f"{test_name}: max_diff is NaN for {dtype}"
        )
        self.assertFalse(
            math.isnan(min_diff), f"{test_name}: min_diff is NaN for {dtype}"
        )
        self.assertFalse(
            math.isnan(diff_range), f"{test_name}: diff_range is NaN for {dtype}"
        )

        # Check that all diffs are exactly 0
        self.assertEqual(
            max_diff,
            0.0,
            f"{test_name}: max_diff must be 0 in batch-invariant mode, got {max_diff} for {dtype}",
        )
        self.assertEqual(
            min_diff,
            0.0,
            f"{test_name}: min_diff must be 0 in batch-invariant mode, got {min_diff} for {dtype}",
        )
        self.assertEqual(
            diff_range,
            0.0,
            f"{test_name}: diff_range must be 0 in batch-invariant mode, got {diff_range} for {dtype}",
        )

    def test_small_matrices(self):
        """Test batch invariance with small matrix sizes."""
        test_cases = [
            ("Small-1", 8, 64, 128),
            ("Small-2", 16, 128, 256),
            ("Small-3", 4, 32, 64),
        ]
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=5, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_medium_matrices(self):
        """Test batch invariance with medium matrix sizes."""
        test_cases = [
            ("Medium-1", 32, 128, 1024),
            ("Medium-2", 64, 512, 2048),
            ("Medium-3", 24, 192, 768),
        ]
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=5, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_large_matrices(self):
        """Test batch invariance with large matrix sizes."""
        test_cases = [
            ("Large-1", 128, 1024, 4096),
            ("Large-2", 256, 2048, 8192),
            ("Large-3", 96, 768, 3072),
        ]
        for name, M, K, N in test_cases:
            with self.subTest(name=name, M=M, K=K, N=N):
                for dtype in [torch.float32, torch.bfloat16]:
                    with self.subTest(dtype=dtype):
                        # Run with batch-invariant mode
                        with set_batch_invariant_mode(True):
                            difflist = self._run_multiple_iterations(
                                iters=5, M=M, K=K, N=N, dtype=dtype
                            )
                        self._assert_batch_invariant_results(difflist, dtype, name)

    def test_without_batch_invariant_mode(self):
        """
        Test that without batch-invariant mode, results may differ.
        This test demonstrates the difference batch-invariant mode makes.
        """
        M, K, N = 32, 128, 1024
        dtype = torch.float32

        # Run without batch-invariant mode
        with set_batch_invariant_mode(False):
            difflist = self._run_multiple_iterations(
                iters=5, M=M, K=K, N=N, dtype=dtype
            )
        print(f"Without batch-invariant mode, we get diffs: {difflist}")


if __name__ == "__main__":
    unittest.main()
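
Outside the test harness, the context manager is used the same way. A minimal usage sketch, assuming a CUDA device and the sglang import path shown above:

    import torch
    from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode

    a = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)
    with set_batch_invariant_mode(True):
        out_single = torch.mm(a[:1], b)   # batch size 1
        out_sliced = torch.mm(a, b)[:1]   # full batch, then slice
    # With batch-invariant matmuls enabled, the two paths should match exactly.
    assert torch.equal(out_single, out_sliced)
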
@@ -33,6 +33,7 @@ suites = {
TestFile("models/test_generation_models.py", 103),
TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
TestFile("models/test_qwen_models.py", 82),
TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
TestFile("models/test_reward_models.py", 132),
TestFile("models/test_transformers_models.py", 320),
TestFile("models/test_vlm_models.py", 741),
......