Unverified Commit 0194948f authored by Stefan He, committed by GitHub

Optimize Triton Kernel of Group GEMM in DeepGEMM Benchmark (#4014)

parent b4d34cd3
@@ -115,17 +115,17 @@ def fp8_gemm_group_triton_kernel(
 ):
     """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+    Note: Block sizes must be multiples of 32 for optimal TMA performance.
     """
     # Map program ids to the block of C it should compute
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
+    pid_group = tl.program_id(axis=0)  # Group ID
+    pid_n = tl.program_id(axis=1)  # N dimension ID
+
+    # Compute the M block ID within this group
+    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
+    pid_m_within_group = tl.program_id(axis=2) % group_size_m
+    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
 
     # Create pointers for the first blocks of A and B
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
@@ -153,20 +153,15 @@ def fp8_gemm_group_triton_kernel(
             pid_n * stride_b_scale_n + k_block * stride_b_scale_k
         )
 
+        # Perform matrix multiplication in FP8
+        res = tl.dot(a, b)
+
         # Load scaling factors for the current block
         a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
         b_scale = tl.load(b_scale_ptrs)
 
-        # Convert FP8 to FP32 for computation
-        a = a.to(tl.float32)
-        b = b.to(tl.float32)
-
-        # Apply scaling factors to the current block
-        a = a * a_scale
-        b = b * b_scale
-
-        # Accumulate matmul for the current block
-        accumulator += tl.dot(a, b)
+        # Apply scaling factors to the accumulated result
+        accumulator += res * a_scale * b_scale
 
         # Advance pointers
         a_ptrs += BLOCK_SIZE_K * stride_ak
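
The core change in this hunk: do a single tl.dot on the FP8 tiles and apply both scaling factors to the block result, instead of up-casting and scaling each operand before the dot. A minimal CPU sketch in plain PyTorch (FP32 stand-ins with illustrative shapes, not the kernel itself) of why the two orderings agree for a per-row A scale and a per-block B scale; in the real kernel the only remaining difference is the rounding behaviour of the FP8 dot:

import torch

# One K-block: A tile [M, K], B tile [K, N], per-row scale for A, scalar scale for B
A = torch.randn(8, 16)
B = torch.randn(16, 4)
a_scale = torch.rand(8, 1)      # analogous to tl.load(a_scale_ptrs)[:, None]
b_scale = torch.tensor(0.37)    # analogous to tl.load(b_scale_ptrs)

old_way = (A * a_scale) @ (B * b_scale)  # scale both operands, then dot
new_way = (A @ B) * a_scale * b_scale    # dot once, then scale the block result
assert torch.allclose(old_way, new_way, atol=1e-5)
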
@@ -183,13 +178,14 @@ def fp8_gemm_group_triton_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
+def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
     """
     Perform matrix multiplication with FP8 inputs and proper scaling.
 
     Args:
         a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
         b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
+        c: Output tensor in BF16 format
         num_groups: Number of groups for grouped GEMM
 
     Returns:
@@ -199,32 +195,21 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
     a, a_scale = a_tuple
     b, b_scale = b_tuple
 
-    # Check constraints
-    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
-    assert a.is_contiguous(), "Matrix A must be contiguous"
-
     M, K = a.shape
-    N, K_b = b.shape
-    assert K == K_b, f"Incompatible K dimensions: {K} vs {K_b}"
-
-    # Transpose b to match kernel expectations (K,N format)
-    b = b.T.contiguous()
-
-    # Allocate output in bfloat16 (not float16)
-    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
-
-    # Prepare scale factors
-    # Ensure scales are in the right format and contiguous
-    a_scale = a_scale.contiguous()
-    b_scale = b_scale.contiguous()
-
-    # 1D launch kernel
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
-
-    # Calculate K blocks (128 elements per block)
-    K_blocks = triton.cdiv(K, 128)
+    _, N = b.shape
+
+    # Configure block sizes - must be multiples of 32 for TMA alignment
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_N = 128
+    BLOCK_SIZE_K = 128
+
+    # Calculate grid dimensions
+    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
+    num_groups_grid = triton.cdiv(num_pid_m, num_groups)
+
+    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
+    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
 
     fp8_gemm_group_triton_kernel[grid](
         a,
@@ -245,9 +230,9 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
         1,  # Stride in the K dimension may be 1
         b_scale.stride(0),
         1 if b_scale.dim() > 1 else 0,
-        BLOCK_SIZE_M=128,
-        BLOCK_SIZE_N=128,
-        BLOCK_SIZE_K=128,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
         GROUP_SIZE_M=num_groups,
     )
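
For a concrete sense of the new launch shape, here is a small sketch (plain Python, example sizes taken from the benchmark configs, not part of the commit) of the 3D grid computed above for M=4096, N=7168 and num_groups=8. It covers the same number of output tiles as the old 1D launch, and only over-provisions slightly when the number of M-blocks is not a multiple of num_groups:

import math

M, N, num_groups = 4096, 7168, 8
BLOCK_SIZE_M = BLOCK_SIZE_N = 128

num_pid_m = math.ceil(M / BLOCK_SIZE_M)              # 32 M-blocks
num_pid_n = math.ceil(N / BLOCK_SIZE_N)              # 56 N-blocks
num_groups_grid = math.ceil(num_pid_m / num_groups)  # 4 groups of M-blocks

grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
print(grid)                                                # (4, 56, 8)
print(grid[0] * grid[1] * grid[2], num_pid_m * num_pid_n)  # 1792 1792
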
@@ -264,52 +249,6 @@ def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
     return out
 
 
-def get_weight_shapes(tp_size):
-    # cannot TP
-    total = [
-        (512 + 64, 7168),
-        ((128 + 64) * 128, 7168),
-        (128 * (128 + 128), 512),
-        (7168, 16384),
-        (7168, 18432),
-    ]
-    # N can TP
-    n_tp = [
-        (18432 * 2, 7168),
-        ((128 + 64) * 128, 7168),
-        (128 * (128 + 128), 512),
-        (24576, 1536),
-        (4096, 7168),
-    ]
-    # K can TP
-    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
-
-    weight_shapes = []
-    for t in total:
-        weight_shapes.append(t)
-    for n_t in n_tp:
-        new_t = (n_t[0] // tp_size, n_t[1])
-        weight_shapes.append(new_t)
-    for k_t in k_tp:
-        new_t = (k_t[0], k_t[1] // tp_size)
-        weight_shapes.append(new_t)
-    return weight_shapes
-
-
-def create_benchmark_configs(tp_size):
-    configs = []
-    weight_shapes = get_weight_shapes(tp_size)
-    batch_sizes = [2048, 4096]
-    group_sizes = [4, 8]
-
-    for n, k in weight_shapes:
-        for m in batch_sizes:
-            for num_groups in group_sizes:
-                configs.append((m, n, k, num_groups, tp_size))
-    return configs
-
-
 def calculate_diff(m: int, n: int, k: int, num_groups: int):
     print(f"Shape (m={m}, n={n}, k={k}")
     x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
@@ -332,8 +271,16 @@ def calculate_diff(m: int, n: int, k: int, num_groups: int):
     )
     torch.cuda.synchronize()
 
-    # Quantized x and y
-    out_triton = fp8_gemm_group_triton(x_fp8_flat, y_fp8_flat, num_groups)
+    # Prepare inputs for Triton
+    a, a_scale = x_fp8_flat
+    b, b_scale = y_fp8_flat
+    b = b.T.contiguous()
+    # Ensure scales are in the right format and contiguous
+    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+    M, _ = a.shape
+    _, N = b.shape
+    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
     torch.cuda.synchronize()
 
     diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
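
The same input preparation now appears twice: here in calculate_diff and again in get_benchmark further down. As a hedged sketch (a hypothetical helper, not part of this commit), both call sites could share something like:

import torch

def prepare_triton_inputs(x_fp8_flat, y_fp8_flat):
    """Unpack quantized inputs, transpose B to (K, N), and allocate the BF16 output."""
    a, a_scale = x_fp8_flat
    b, b_scale = y_fp8_flat
    b = b.T.contiguous()  # kernel expects B in (K, N) layout
    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
    c = torch.empty((a.shape[0], b.shape[1]), device=a.device, dtype=torch.bfloat16)
    return (a, a_scale), (b, b_scale), c

# Hypothetical usage at either call site:
# a_tuple, b_tuple, c = prepare_triton_inputs(x_fp8_flat, y_fp8_flat)
# out_triton = fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups)
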
@@ -369,6 +316,52 @@ def calculate_diff(m: int, n: int, k: int, num_groups: int):
     )
 
 
+def get_weight_shapes(tp_size):
+    # cannot TP
+    total = [
+        (512 + 64, 7168),
+        ((128 + 64) * 128, 7168),
+        (128 * (128 + 128), 512),
+        (7168, 16384),
+        (7168, 18432),
+    ]
+    # N can TP
+    n_tp = [
+        (18432 * 2, 7168),
+        ((128 + 64) * 128, 7168),
+        (128 * (128 + 128), 512),
+        (24576, 1536),
+        (4096, 7168),
+    ]
+    # K can TP
+    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
+
+    weight_shapes = []
+    for t in total:
+        weight_shapes.append(t)
+    for n_t in n_tp:
+        new_t = (n_t[0] // tp_size, n_t[1])
+        weight_shapes.append(new_t)
+    for k_t in k_tp:
+        new_t = (k_t[0], k_t[1] // tp_size)
+        weight_shapes.append(new_t)
+    return weight_shapes
+
+
+def create_benchmark_configs(tp_size):
+    configs = []
+    weight_shapes = get_weight_shapes(tp_size)
+    batch_sizes = [2048, 4096]
+    group_sizes = [4, 8]
+
+    for n, k in weight_shapes:
+        for m in batch_sizes:
+            for num_groups in group_sizes:
+                configs.append((m, n, k, num_groups, tp_size))
+    return configs
+
+
 def get_benchmark(tp_size):
     all_configs = create_benchmark_configs(tp_size)
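
The two helpers above are only moved below calculate_diff, not changed, so the sweep size stays the same. A quick sketch of the count (plain Python, numbers taken from the lists above):

num_shapes = 5 + 5 + 3     # "cannot TP" + "N can TP" + "K can TP" weight shapes
num_batch_sizes = 2        # [2048, 4096]
num_group_sizes = 2        # [4, 8]
print(num_shapes * num_batch_sizes * num_group_sizes)  # 52 benchmark configs per tp_size
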
@@ -416,10 +409,21 @@ def get_benchmark(tp_size):
                 quantiles=quantiles,
             )
         elif provider == "triton":
+            # Prepare inputs for Triton.
+            # This is done outside the lambda so the comparison with deepgemm stays fair.
+            a, a_scale = x_fp8_flat
+            b, b_scale = y_fp8_flat
+            b = b.T.contiguous()
+            # Ensure scales are in the right format and contiguous
+            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+            M, _ = a.shape
+            _, N = b.shape
+            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
             ms, min_ms, max_ms = triton.testing.do_bench(
                 lambda: fp8_gemm_group_triton(
-                    x_fp8_flat,
-                    y_fp8_flat,
+                    (a, a_scale),
+                    (b, b_scale),
+                    c,
                     num_groups,
                 ),
                 quantiles=quantiles,
@@ -429,13 +433,8 @@ def get_benchmark(tp_size):
         flops = 2 * m * n * k  # multiply-adds
         tflops = flops / (ms * 1e-3) / 1e12
 
-        # Print shape-specific results with TFLOPS
-        print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")
-
-        return (
-            ms,
-            max_ms,
-            min_ms,
-        )  # return in seconds for consistency with triton benchmark
+        print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
+        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms
 
     return benchmark
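
As a sanity check on the numbers this prints, a quick sketch of the TFLOPS arithmetic above with one of the calculate_diff shapes and a made-up timing:

m, n, k = 4096, 7168, 4096     # one of the calculate_diff shapes
ms = 1.0                       # example timing in milliseconds (made up)

flops = 2 * m * n * k          # multiply-adds, about 2.4e11
tflops = flops / (ms * 1e-3) / 1e12
print(f"{tflops:.1f} TFLOPS")  # ~240.5 TFLOPS at 1 ms
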
@@ -478,6 +477,7 @@ if __name__ == "__main__":
     calculate_diff(8192, 2048, 7168, 4)
     calculate_diff(4096, 7168, 4096, 8)
     calculate_diff(4096, 2048, 7168, 8)
+    calculate_diff(4096, 576, 7168, 8)
 
     # Get the benchmark function with the specified tp_size
     benchmark = get_benchmark(args.tp_size)