Unverified Commit 8ecad0b1 authored by Xiaoyu Zhang, committed by GitHub

[benchmark] fbgemm benchmark support bandwidth report and support fbgemm_cutlass_gmm (#7422)

parent 7151194b
## Benchmark FBGEMM Grouped GEMM
Benchmarks FBGEMM Grouped GEMM (both the Triton and the CUDA/CUTLASS versions) against SGLang's Triton Grouped GEMM, and reports the memory bandwidth achieved by each implementation so they can be compared directly.
### Requirements
```shell
pip install fbgemm-gpu-genai
```
### Usage
```bash
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
```
For example, on an H200, the Qwen2-57B-A14B-Instruct TP4 FP8 W8A8 grouped GEMM bandwidth results are as follows:
```shell
grouped-gemm-performance:
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
0 256.0 3704.841339 3042.626402 2254.725030
1 512.0 3691.426346 3029.065684 2269.504543
2 1024.0 3653.938629 2258.471467 2358.319020
3 2048.0 3596.644313 2271.611904 2476.895397
4 4096.0 3468.496435 2231.283986 2179.473910
```
The theoretical peak memory bandwidth of H200 is 4.8 TB/s. Taking batch_size 256 as an example, FBGEMM Triton Grouped GEMM FP8 reaches 3704.84 GB/s, FBGEMM CUTLASS F8F8BF16 Rowwise reaches 3042.63 GB/s, and SGLang Grouped GEMM FP8 reaches 2254.73 GB/s, that is, roughly 77.2%, 63.4%, and 47.0% of the theoretical peak, respectively.
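The utilization percentages can be reproduced with a few lines of Python. This is a minimal sketch for illustration only; the peak and measured bandwidth values are the numbers quoted above, hard-coded rather than produced by the benchmark script:
```python
# Illustration only: recompute the utilization figures quoted above.
H200_PEAK_GB_PER_S = 4800.0  # 4.8 TB/s theoretical peak memory bandwidth

measured_gb_per_s = {
    "FBGEMM Triton Grouped GEMM FP8": 3704.841339,
    "FBGEMM CUTLASS F8F8BF16 Rowwise": 3042.626402,
    "SGLang Grouped GEMM FP8": 2254.725030,
}

for name, bandwidth in measured_gb_per_s.items():
    print(f"{name}: {bandwidth / H200_PEAK_GB_PER_S:.1%} of theoretical peak")
```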
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse

import torch
import triton
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,
    triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
@@ -29,12 +35,11 @@ def get_model_config(model_name: str, tp_size: int):
    elif config.architectures[0] == "Qwen3MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] in [
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
    ]:
        num_groups = config.n_routed_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Llama4ForConditionalGeneration":
        num_groups = config.text_config.num_local_experts
@@ -65,7 +70,7 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
    )
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
@@ -84,11 +89,11 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )

    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
    for i in range(1, num_groups + 1):
        seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group

    weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")

    return (
        x,
@@ -102,39 +107,144 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    )

def create_fp8_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
    """
    Create test data for FP8 grouped GEMM operations.

    Args:
        batch_size: Total batch size
        num_groups: Number of groups
        hidden_size: Hidden dimension size
        intermediate_size: Intermediate dimension size
        backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM

    Returns:
        For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
        For cutlass: (x, wq, w_scale, m_sizes)
    """
    torch.manual_seed(42)
    tokens_per_group = batch_size // num_groups

    # Create weight matrices for each group
    w_list = []
    for _ in range(num_groups):
        w = torch.randn(
            intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
        )
        w_list.append(w)

    # Quantize weights using quantize_fp8_row for each group
    wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])

    if backend == "triton":
        # Triton format: concatenated weights
        w_fp8 = torch.concat(wq_list, dim=0).contiguous()
        w_scale = torch.concat(w_scale_list, dim=0).contiguous()

        # Create m_sizes as int32 for triton
        m_sizes = torch.full(
            (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
        )

        # Create and quantize input
        x_fp16 = torch.randn(
            batch_size, hidden_size, dtype=torch.float16, device="cuda"
        )
        x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
        x_scale = x_scale.view(batch_size, -1)

        return x_fp8, w_fp8, m_sizes, x_scale, w_scale
    elif backend == "cutlass":
        # CUTLASS format: stacked weights
        wq = torch.stack(wq_list, dim=0).contiguous()
        w_scale = torch.stack(w_scale_list, dim=0).contiguous()

        # Create m_sizes as int64 for cutlass
        m_values = [tokens_per_group] * num_groups
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")

        # Create input data - separate for each group then concat
        x_list = []
        for _ in range(num_groups):
            x = torch.randn(
                tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
            )
            x_list.append(x)

        # Concatenate inputs into single tensor
        x = torch.concat(x_list, dim=0).contiguous()

        return x, wq, w_scale, m_sizes
    else:
        raise ValueError(f"Unsupported backend: {backend}")

def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
    """
    Calculate the memory traffic of the accessed expert weights.

    Args:
        m_sizes: Tensor containing batch sizes for each group
        hidden_size: Hidden dimension size
        intermediate_size: Intermediate dimension size
        dtype: Data type of weights

    Returns:
        Memory size in bytes for accessed expert weights
    """
    # Count non-zero groups (active experts)
    if hasattr(m_sizes, "cpu"):
        active_experts = torch.count_nonzero(m_sizes).item()
    else:
        active_experts = sum(1 for m in m_sizes if m > 0)

    # Calculate bytes per element based on dtype
    if dtype in [torch.float16, torch.bfloat16]:
        bytes_per_element = 2
    elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        bytes_per_element = 1
    elif dtype == torch.float32:
        bytes_per_element = 4
    else:
        # Default to 2 bytes for unknown dtypes
        bytes_per_element = 2

    # Memory per expert weight matrix
    memory_per_expert = hidden_size * intermediate_size * bytes_per_element

    # Total memory for active experts
    total_memory_bytes = active_experts * memory_per_expert

    return total_memory_bytes

def get_benchmark_config(use_fp8_w8a8=False):
    if use_fp8_w8a8:
        return {
            "line_vals": [
                "fbgemm_triton_grouped_gemm_fp8",
                "fbgemm_cutlass_f8f8bf16_rowwise",
                "sglang_grouped_gemm",
            ],
            "line_names": [
                "FBGEMM Triton Grouped GEMM FP8",
                "FBGEMM CUTLASS F8F8BF16 Rowwise",
                "SGLang Grouped GEMM FP8",
            ],
            "styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
        }
    else:
        return {
            "line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
            "line_names": [
                "FBGEMM Triton Grouped GEMM BF16",
                "SGLang Grouped GEMM BF16",
            ],
            "styles": [("blue", "-"), ("green", "-")],
        }
@@ -146,12 +256,12 @@ def run_benchmark(
    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Bandwidth (GB/s)",
        plot_name="grouped-gemm-performance",
        args={},
    )
@@ -165,13 +275,22 @@ def run_benchmark(
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_triton_grouped_gemm_fp8":
            try:
                test_data = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="triton",
                )
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data

                # Calculate memory bandwidth
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
@@ -180,6 +299,38 @@ def run_benchmark(
            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return float("inf"), float("inf"), float("inf")
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="cutlass",
)
x, wq, w_scale, m_sizes = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
# Quantize input using triton_quantize_fp8_row
xq, x_scale = triton_quantize_fp8_row(x)
x_scale = x_scale.view(batch_size, -1)
def run_func():
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
)
except Exception as e:
print(
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
f"skipping: {e}"
)
return float("inf"), float("inf"), float("inf")
        else:
            test_data = create_test_data(
                batch_size, num_groups, hidden_size, intermediate_size
@@ -195,7 +346,12 @@ def run_benchmark(
                weight_indices,
            ) = test_data

            # Calculate memory bandwidth for BF16 operations
            memory_bytes = calculate_memory_bandwidth(
                m_sizes, hidden_size, intermediate_size, torch.bfloat16
            )

            if provider == "fbgemm_triton_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
@@ -228,10 +384,19 @@ def run_benchmark(
        try:
            quantiles = [0.5, 0.2, 0.8]
            ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)

            # Convert time (ms) to bandwidth (GB/s)
            # Bandwidth = Memory (bytes) / Time (seconds)
            # Convert ms to seconds and bytes to GB (1e9)
            gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
            # min bandwidth = max time, max bandwidth = min time
            min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
            max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)

            return gb_per_s, min_gb_per_s, max_gb_per_s
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return 0.0, 0.0, 0.0

    dynamic_benchmark.run(
        show_plots=True,
@@ -242,7 +407,7 @@ def run_benchmark(
    )

def verify_correctness(model_config):
    print("Verifying correctness...")

    batch_size = 128
    num_groups = model_config["num_groups"]
@@ -250,54 +415,39 @@ def verify_correctness(model_config, use_fp8_w8a8):
    intermediate_size = model_config["intermediate_size"]

    test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
    (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    ) = test_data

    result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
        print("✓ BF16 Correctness verification passed!")
    else:
        max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
        print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
        return False

    return True

def main():
    parser = argparse.ArgumentParser(
@@ -348,7 +498,7 @@ def main():
    print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")

    if args.verify_correctness:
        if not verify_correctness(model_config):
            print("Correctness verification failed. Exiting...")
            return
...
The commit also adds a pytest test suite comparing the FBGEMM and SGLang grouped GEMM kernels:
import os
import sys
import pytest
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
    from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
    from fbgemm_grouped_gemm import (
        grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
    )

    FBGEMM_AVAILABLE = True
    print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
    FBGEMM_AVAILABLE = False

try:
    from sglang.srt.layers.moe.ep_moe.kernels import (
        grouped_gemm_triton as sglang_grouped_gemm,
    )

    SGLANG_AVAILABLE = True
    print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import SGLang grouped GEMM: {e}")
    SGLANG_AVAILABLE = False

def create_uniform_groups(batch_size, num_groups, device):
    tokens_per_group = batch_size // num_groups
    return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device)


def create_non_uniform_groups(batch_size, num_groups, device):
    # Randomly split batch_size across groups so that every group gets at least one row.
    remaining = batch_size
    m_sizes = []
    for i in range(num_groups - 1):
        if remaining <= 1:
            size = 1
        else:
            max_size = remaining - (num_groups - i - 1) + 1
            size = torch.randint(1, max_size, (1,)).item()
        m_sizes.append(size)
        remaining -= size
    m_sizes.append(remaining)
    return torch.tensor(m_sizes, dtype=torch.int64, device=device)


def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
    # Build the output buffer, segment pointers, and weight view expected by the SGLang kernel.
    batch_size = x.shape[0]
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
    current_pos = 0
    for i, size in enumerate(m_sizes):
        current_pos += size
        seg_indptr[i + 1] = current_pos
    weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
    w_sglang = w.view(num_groups, intermediate_size, -1)
    return c_sglang, seg_indptr, weight_indices, w_sglang


def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
    torch.manual_seed(42)
    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
    )
    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)
    x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4
    return x_fp8, w_fp8, x_scale, w_scale


@pytest.fixture
def device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return torch.device("cuda")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)
    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)
    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    torch.manual_seed(42)
    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
        assert result_fp8.shape == (batch_size, intermediate_size)
        assert result_fp8.dtype == torch.bfloat16
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
    torch.manual_seed(42)
    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16


@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
    torch.manual_seed(42)
    batch_size, num_groups = 64, 4
    hidden_size, intermediate_size = 512, 1024

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
    w = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
        x, w, m_sizes, num_groups, intermediate_size, device
    )
    result = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    assert result.shape == (batch_size, intermediate_size)
    assert result.dtype == torch.bfloat16


def test_imports():
    assert (
        FBGEMM_AVAILABLE or SGLANG_AVAILABLE
    ), "Neither FBGEMM nor SGLang is available"

if __name__ == "__main__":
    pytest.main([__file__, "-v"])