Unverified Commit c14cc47e authored by Minglei Zhu's avatar Minglei Zhu Committed by GitHub
Browse files

[Deterministic] Optimize bmm_batch_invariant op (#12522)

parent dbcf85b7
......@@ -559,19 +559,215 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
@triton.jit
def bmm_kernel_persistent(
a_ptr,  # *A: batched LHS, logical shape (B, M, K)
b_ptr,  # *B: batched RHS, logical shape (B, K, N)
c_ptr, # *C: batched output, logical shape (B, M, N)
B,  # number of batches
M,
N,
K, # per-batch GEMM sizes: C[b] (M, N) = A[b] (M, K) @ B[b] (K, N)
stride_ab,  # A strides: batch, M, K
stride_am,
stride_ak,
stride_bb,  # B strides: batch, K, N
stride_bk,
stride_bn,
stride_cb,  # C strides: batch, M, N
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, # tile-group width, forwarded to _compute_pid
NUM_SMS: tl.constexpr, # launched grid size; also each program's stride over tiles
A_LARGE: tl.constexpr,  # True when A exceeds 2**31 elements -> use int64 offsets
B_LARGE: tl.constexpr,  # same for B
C_LARGE: tl.constexpr,  # same for C
):
"""
Batched matrix multiplication kernel that processes batches in parallel.
Each tile processes a (BLOCK_SIZE_M, BLOCK_SIZE_N) output block for a specific batch.

Persistent-kernel scheme: NUM_SMS programs are launched, and each walks
tile indices start_pid, start_pid + NUM_SMS, ... across all
B * (tiles per batch) output tiles. Each tile's K-reduction is therefore
always accumulated in the same order, independent of hardware scheduling,
which is what makes the op deterministic/batch-invariant.
"""
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)  # output tile count along M
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)  # output tile count along N
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)  # reduction steps per output tile
num_tiles_per_batch = num_pid_m * num_pid_n
num_tiles_total = B * num_tiles_per_batch
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Process tiles in a deterministic order: batch-major ordering
for tile_id in tl.range(start_pid, num_tiles_total, NUM_SMS, flatten=True):
# Decompose tile_id into batch and within-batch tile
batch_idx = tile_id // num_tiles_per_batch
tile_in_batch = tile_id % num_tiles_per_batch
# _compute_pid (helper defined elsewhere in this file) maps the flat
# within-batch tile index to (pid_m, pid_n); presumably it applies a
# GROUP_SIZE_M swizzle for L2 reuse -- confirm against its definition.
pid_m, pid_n = _compute_pid(
tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
if A_LARGE:
offs_am = offs_am.to(tl.int64)  # avoid 32-bit offset overflow
if B_LARGE:
offs_bn = offs_bn.to(tl.int64)
# Clamp out-of-range rows/cols to index 0 so the loads below stay in
# bounds; the masked store at the end discards those lanes anyway.
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
# Contiguity/alignment hints to enable vectorized memory access.
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
# Add batch offset
if A_LARGE or B_LARGE:
batch_idx_typed = batch_idx.to(tl.int64)
else:
batch_idx_typed = batch_idx
# Accumulate in fp32 regardless of input dtype for precision.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
if A_LARGE or B_LARGE:
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
else:
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
batch_idx_typed * stride_ab
+ offs_am[:, None] * stride_am
+ offs_k[None, :] * stride_ak
)
b_ptrs = b_ptr + (
batch_idx_typed * stride_bb
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
# Only K needs masking here (M/N were clamped above); out-of-range K
# lanes load 0.0 and contribute nothing to the dot product.
a = tl.load(
a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
)
b = tl.load(
b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
)
accumulator = tl.dot(a, b, accumulator)
# Epilogue: recompute output offsets and store the tile.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if C_LARGE:
offs_cm = offs_cm.to(tl.int64)
offs_cn = offs_cn.to(tl.int64)
c_ptrs = (
c_ptr
+ batch_idx_typed * stride_cb
+ stride_cm * offs_cm[:, None]
+ stride_cn * offs_cn[None, :]
)
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# Cast the fp32 accumulator down to C's element type before storing.
if c_ptr.dtype.element_ty == tl.float8e4nv:
c = accumulator.to(tl.float8e4nv)
elif c_ptr.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif c_ptr.dtype.element_ty == tl.float32:
c = accumulator.to(tl.float32)
else:
c = accumulator.to(tl.float16)
tl.store(c_ptrs, c, mask=c_mask)
def bmm_batch_invariant(a, b, *, out=None):
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
# Process each batch separately with our persistent kernel
# Process batches in parallel with our persistent kernel
if a.ndim == 3 and b.ndim == 3:
results = []
for i in range(a.shape[0]):
results.append(matmul_persistent(a[i], b[i]))
result = torch.stack(results, dim=0)
if out is not None:
out.copy_(result)
return out
return result
# Check constraints
assert a.shape[0] == b.shape[0], "Batch sizes must match"
assert a.shape[2] == b.shape[1], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
B = a.shape[0]
M = a.shape[1]
K = a.shape[2]
N = b.shape[2]
dtype = a.dtype
# Allocate output
if out is None:
c = torch.empty((B, M, N), device=a.device, dtype=dtype)
else:
c = out
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
# Use fixed kernel configuration for determinism
configs = {
torch.bfloat16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
torch.float16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
torch.float32: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
}
config = configs.get(dtype)
if config is None:
raise ValueError(
f"Unsupported dtype {dtype} for bmm_batch_invariant. "
f"Supported dtypes are: {list(configs.keys())}"
)
# Grid: limit by NUM_SMS for persistent kernel approach
num_tiles_per_batch = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
num_tiles_total = B * num_tiles_per_batch
grid = (min(NUM_SMS, num_tiles_total),)
bmm_kernel_persistent[grid](
a,
b,
c, #
B,
M,
N,
K, #
a.stride(0),
a.stride(1),
a.stride(2), #
b.stride(0),
b.stride(1),
b.stride(2), #
c.stride(0),
c.stride(1),
c.stride(2), #
NUM_SMS=NUM_SMS, #
A_LARGE=a.numel() > 2**31,
B_LARGE=b.numel() > 2**31,
C_LARGE=c.numel() > 2**31,
**config,
)
return c
else:
raise ValueError(
f"bmm_batch_invariant expects 3D tensors, "
......
......@@ -167,6 +167,92 @@ class TestBatchInvariantOps(CustomTestCase):
)
print(f"Without batch-invariant mode, we get diffs: {difflist}")
def _test_bmm_batch_invariance(self, B, M, K, N, dtype):
"""
Test that BMM operations produce identical results for:
- Method 1: BMM with subset of batches
- Method 2: BMM with all batches, then slice
"""
a = torch.linspace(-100, 100, B * M * K, dtype=dtype).reshape(B, M, K)
b = torch.linspace(-100, 100, B * K * N, dtype=dtype).reshape(B, K, N)
# Method 1: BMM with subset (first 2 batches)
subset_size = min(2, B)
out1 = torch.bmm(a[:subset_size], b[:subset_size])
# Method 2: BMM with all batches, then slice
out2_pre = torch.bmm(a, b)
out2 = out2_pre[:subset_size]
# Check if results are identical
diff = (out1 - out2).abs().max()
return diff.item()
def _run_bmm_multiple_iterations(self, iters, B, M, K, N, dtype):
"""Run multiple BMM iterations and collect diff statistics"""
difflist = []
for _ in range(iters):
diff = self._test_bmm_batch_invariance(B, M, K, N, dtype)
difflist.append(diff)
return difflist
def test_bmm_small_matrices(self):
    """BMM batch invariance on small shapes."""
    cases = (
        ("BMM-Small-1", 4, 8, 64, 128),
        ("BMM-Small-2", 8, 16, 128, 256),
        ("BMM-Small-3", 6, 4, 32, 64),
    )
    for case in cases:
        name, B, M, K, N = case
        with self.subTest(name=name, B=B, M=M, K=K, N=N):
            for dtype in (torch.float32, torch.bfloat16):
                with self.subTest(dtype=dtype):
                    # Collect diffs with batch-invariant mode enabled.
                    with set_batch_invariant_mode(True):
                        difflist = self._run_bmm_multiple_iterations(
                            iters=5, B=B, M=M, K=K, N=N, dtype=dtype
                        )
                    self._assert_batch_invariant_results(difflist, dtype, name)
def test_bmm_medium_matrices(self):
    """BMM batch invariance on medium shapes."""
    cases = (
        ("BMM-Medium-1", 8, 32, 128, 1024),
        ("BMM-Medium-2", 16, 64, 512, 2048),
        ("BMM-Medium-3", 12, 24, 192, 768),
    )
    for case in cases:
        name, B, M, K, N = case
        with self.subTest(name=name, B=B, M=M, K=K, N=N):
            for dtype in (torch.float32, torch.bfloat16):
                with self.subTest(dtype=dtype):
                    # Collect diffs with batch-invariant mode enabled.
                    with set_batch_invariant_mode(True):
                        difflist = self._run_bmm_multiple_iterations(
                            iters=5, B=B, M=M, K=K, N=N, dtype=dtype
                        )
                    self._assert_batch_invariant_results(difflist, dtype, name)
def test_bmm_large_matrices(self):
    """BMM batch invariance on large shapes."""
    cases = (
        ("BMM-Large-1", 16, 128, 1024, 4096),
        ("BMM-Large-2", 32, 256, 2048, 8192),
        ("BMM-Large-3", 24, 96, 768, 3072),
    )
    for case in cases:
        name, B, M, K, N = case
        with self.subTest(name=name, B=B, M=M, K=K, N=N):
            for dtype in (torch.float32, torch.bfloat16):
                with self.subTest(dtype=dtype):
                    # Collect diffs with batch-invariant mode enabled.
                    with set_batch_invariant_mode(True):
                        difflist = self._run_bmm_multiple_iterations(
                            iters=5, B=B, M=M, K=K, N=N, dtype=dtype
                        )
                    self._assert_batch_invariant_results(difflist, dtype, name)
if __name__ == "__main__":
    # Run the test suite when this file is executed directly.
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment