Unverified Commit bae4fdc7 authored by Xiaoyu Zhang, committed by GitHub

add fbgemm moe grouped gemm kernel benchmark (#6924)

parent 6153f2ff
# python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse
import torch
import triton
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
from fbgemm_grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
def get_model_config(model_name: str, tp_size: int):
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
num_groups = config.ffn_config.moe_num_experts
intermediate_size = config.ffn_config.ffn_hidden_size
elif config.architectures[0] == "JambaForCausalLM":
num_groups = config.num_experts
intermediate_size = config.intermediate_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
num_groups = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
num_groups = config.text_config.num_local_experts
intermediate_size = config.text_config.intermediate_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
num_groups = config.num_local_experts
intermediate_size = config.moe_intermediate_size
else:
num_groups = config.num_local_experts
intermediate_size = config.intermediate_size
shape_configs = {
"num_groups": num_groups,
"hidden_size": config.hidden_size,
"intermediate_size": intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
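# Example (for illustration): with the default mistralai/Mixtral-8x7B-Instruct-v0.1 model
# this resolves to the same values as the fallback configuration in main(), i.e.
# {"num_groups": 8, "hidden_size": 4096, "intermediate_size": 14336, "dtype": torch.bfloat16}.
# Note that tp_size is accepted for CLI symmetry but is not folded into these sizes here.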
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
base_weights = torch.randn(
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
w_sglang = base_weights
c_fbgemm = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
for i in range(1, num_groups + 1):
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda")
return (
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
)
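# Layout note: FBGEMM's grouped_gemm takes the expert weights stacked row-wise as a single
# [num_groups * intermediate_size, hidden_size] matrix, while SGLang's grouped_gemm_triton
# takes them as [num_groups, intermediate_size, hidden_size] together with a seg_indptr
# running sum that marks each group's contiguous token segment in x.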
def create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
)
x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
w_fp16 = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
)
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
w_fp8 = w_fp16.to(torch.float8_e4m3fn)
x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4
w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
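# The FP8 path casts float16 test data to float8_e4m3fn and draws strictly positive random
# scales: x_scale holds one scale per activation row, w_scale one scale per expert group.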
def get_benchmark_config(use_fp8_w8a8=False):
if use_fp8_w8a8:
return {
"line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"],
"line_names": ["FBGEMM Grouped GEMM FP8", "SGLang Grouped GEMM FP8"],
"styles": [("blue", "-"), ("red", "-")],
}
else:
return {
"line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"],
"line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"],
"styles": [("blue", "-"), ("green", "-")],
}
def run_benchmark(
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
config = get_benchmark_config(use_fp8_w8a8)
benchmark_config = triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
line_arg="provider",
line_vals=config["line_vals"],
line_names=config["line_names"],
styles=config["styles"],
ylabel="Time (ms)",
plot_name="grouped-gemm-performance",
args={},
)
@triton.testing.perf_report(benchmark_config)
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"Benchmarking {provider} with batch_size={batch_size}")
torch.cuda.manual_seed_all(0)
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
if provider == "fbgemm_grouped_gemm_fp8":
try:
test_data = create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
def run_func():
return fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
except Exception as e:
print(f"FP8 not supported, skipping: {e}")
return float("inf"), float("inf"), float("inf")
else:
test_data = create_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
if provider == "fbgemm_grouped_gemm":
def run_func():
return fbgemm_grouped_gemm(
x, w_fbgemm, m_sizes, use_fast_accum=True
)
else:
def run_func():
return sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
for _ in range(10):
try:
run_func()
except Exception as e:
print(f"Error during warmup for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
torch.cuda.synchronize()
try:
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
return ms, min_ms, max_ms
except Exception as e:
print(f"Error during benchmarking for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
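    # Each curve reports the median runtime in ms; the 0.2/0.8 quantiles become the
    # min/max bands in the saved plot, and failed providers surface as inf.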
dynamic_benchmark.run(
show_plots=True,
print_data=True,
save_path=save_path,
model_config=model_config,
use_fp8_w8a8=use_fp8_w8a8,
)
def verify_correctness(model_config, use_fp8_w8a8):
print("Verifying correctness...")
batch_size = 128
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
(x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = (
test_data
)
try:
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
print("✓ BF16 Correctness verification passed!")
else:
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
return False
if use_fp8_w8a8:
try:
fp8_data = create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data
result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True
)
assert result_fp8.shape == (batch_size, intermediate_size)
print("✓ FP8 functionality test passed!")
except Exception as e:
print(f"FP8 test failed (possibly unsupported): {e}")
return False
return True
except Exception as e:
print(f"✗ Error during correctness verification: {e}")
return False
def main():
parser = argparse.ArgumentParser(
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
help="Model name to get configuration from",
)
parser.add_argument(
"--tp-size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument(
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
)
parser.add_argument(
"--save-path",
type=str,
default="./benchmark_grouped_gemm/",
help="Path to save benchmark results",
)
parser.add_argument(
"--verify-correctness",
action="store_true",
help="Verify correctness before benchmarking",
)
args = parser.parse_args()
try:
model_config = get_model_config(args.model, args.tp_size)
except Exception as e:
print(f"Failed to get model config: {e}")
print("Using default configuration...")
model_config = {
"num_groups": 8,
"hidden_size": 4096,
"intermediate_size": 14336,
"dtype": torch.bfloat16,
}
print("Running benchmark with:")
print(f" num_groups: {model_config['num_groups']}")
print(f" hidden_size: {model_config['hidden_size']}")
print(f" intermediate_size: {model_config['intermediate_size']}")
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
if args.verify_correctness:
if not verify_correctness(model_config, args.use_fp8_w8a8):
print("Correctness verification failed. Exiting...")
return
try:
run_benchmark(
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
save_path=args.save_path,
)
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()
# Copy from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import functools
import inspect
import sys
import warnings
from typing import Optional
import torch
import triton # @manual
import triton.language as tl # @manual
from triton.runtime import driver # @manual
def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
"""
Maps torch dtype to triton dtype.
Args:
dtype (torch.dtype): input dtype.
Returns:
tl.dtype: triton dtype.
"""
if dtype == torch.float16:
return tl.float16
elif dtype == torch.bfloat16:
return tl.bfloat16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.int32:
return tl.int32
elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
return tl.float8e4nv
else:
raise ValueError(f"Unsupported dtype {dtype}")
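# For example, map_dtype_to_triton(torch.bfloat16) returns tl.bfloat16; float8_e4m3fn maps
# to tl.float8e4nv on CUDA only (the ROCm float8 type is handled separately via
# TT_FP8_DTYPE further below).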
# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor.",
file=sys.stderr,
)
else:
print(
"TMA benchmarks will be running without grid constant TMA descriptor.",
file=sys.stderr,
)
class TmaAutoTuneHelper:
# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc
def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()
TMA_SIZE = 128
def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}
# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)
# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)
# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)
def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]
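# Typical use (mirroring _grouped_gemm below): construct a TmaAutoTuneHelper, call
# init_tma_descriptor(name) once per tensor outside the grid lambda, fill the descriptor
# inside the grid lambda via fill_1d/2d_tma_descriptor(...) once the block sizes are known
# from META, and pass get_tma_descriptor_kernel_param(name) to the kernel as the desc_ptr.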
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [64, 128]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
_HAS_WS_SUPPORT = None
def _check_ws_support():
if not hasattr(tl, "async_task"):
return False
config_signature = inspect.signature(triton.Config).parameters
if (
"num_consumer_groups" not in config_signature
or "num_buffers_warp_spec" not in config_signature
):
return False
if not HAS_TMA_DESC:
return False
return True
def _set_ws_support():
global _HAS_WS_SUPPORT
if _HAS_WS_SUPPORT is None:
_HAS_WS_SUPPORT = _check_ws_support()
_set_ws_support()
if _HAS_WS_SUPPORT:
_NV_WS_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
"USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
"USE_TMA_STORE": use_tma_store,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
num_consumer_groups=num_consumer_groups,
num_buffers_warp_spec=num_stages,
)
for block_size_m in [64, 128, 256]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [2, 3, 4]
for num_warps in [4, 8, 16]
# TODO(shikaili): Resolve LLVM error.
for num_ctas in [1]
for num_consumer_groups in [0, 2]
for use_tma_load_on_scales in [True, False]
# TODO(shikaili): Resolve compatibility with ws.
for use_tma_store in [False]
]
else:
_NV_WS_CONFIGS = _NV_CONFIGS
_AMD_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"waves_per_eu": waves_per_cu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=num_stages,
num_warps=num_warps,
)
for block_size_m in [32, 64, 128]
for block_size_n in [32, 64, 128, 256]
for block_size_k in [128, 256]
for num_stages in [1, 2]
for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
for matrix_instr_nonkdim in [16]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
if dtsize is None:
dtsize = named_args["c_ptr"].element_size()
if dtype is None:
dtype = named_args["c_ptr"].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
(
BLOCK_M,
BLOCK_N,
BLOCK_K,
num_stages,
num_warps,
num_consumer_groups,
use_tma_load_on_scales,
) = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
config.num_warps,
config.num_consumer_groups,
kw.get("USE_TMA_LOAD_ON_SCALES", False),
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
if torch.version.hip:
required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
else:
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
use_warp_specialization = num_consumer_groups >= 1
M_PER_GROUP = M // G
MIN_M_TILES = 32 if torch.version.hip else 64
# 2. make sure we don't load M tiles that are too big
if (
not use_warp_specialization
and BLOCK_M > MIN_M_TILES
and BLOCK_M > (M_PER_GROUP * 2)
):
continue
# 3. make sure we don't load M tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 32 if torch.version.hip else 64
# 4. make sure we don't load N tiles that are too big
if (
not use_warp_specialization
and BLOCK_N > MIN_N_TILES
and M * N_TILES < num_sm
):
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
# 7. make sure we can partition for ws
if use_warp_specialization:
if num_warps != 4:
continue
# "tritongpu-warp-spec-data-partition"
m_slice = BLOCK_M // num_consumer_groups
n_slice = BLOCK_N // num_consumer_groups
if m_slice < 64 and n_slice < 256:
continue
if dtsize >= 2:
if use_tma_load_on_scales:
continue
pruned_configs.append(config)
return pruned_configs
@triton.autotune(
configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
restore_value=["c_ptr"], # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm(
a_desc_ptr,
b_desc_ptr,
c_ptr,
workspace,
scatter_add_indices,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET,
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
FUSE_SCATTER_ADD: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
USE_FAST_ACCUM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
tl.static_assert(
not (FUSE_SCATTER_ADD and USE_TMA_STORE),
"Cannot fuse scatter add with TMA store!",
)
tidx = tl.program_id(0)
dtype: tl.dtype = c_ptr.dtype.element_ty
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
M_end_offset = M_end_offset.to(tl.int64)
iterated_tiles = 0
iterated_tiles = iterated_tiles.to(tl.int64)
for g in tl.range(G):
# Move across groups
m_size = tl.load(m_sizes + g)
if m_size > 0:
M_start_offset = M_end_offset
M_end_offset = M_start_offset + m_size
N_start_offset = g.to(tl.int64) * N
n_size = N
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if USE_TMA_STORE:
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_SIZE_K == 0)
if USE_TMA_LOAD:
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
accumulator += tl.dot(a, b.T)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
if USE_TMA_STORE:
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
elif FUSE_SCATTER_ADD:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = offs_am < m_size
m_offsets = tl.load(
scatter_add_indices + M_start_offset + offs_am,
mask=mask,
cache_modifier=".ca",
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.atomic_add(
c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
c,
mask=mask[:, None] and offs_bn[None, :] < n_size,
sem="relaxed",
)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
)
tidx += NUM_SMS
iterated_tiles += num_tiles
# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
configs=_NV_WS_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
restore_value=["c_ptr"], # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_ws(
a_desc_ptr,
b_desc_ptr,
c_ptr,
workspace,
scatter_add_indices,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
FUSE_SCATTER_ADD: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_FAST_ACCUM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_CONSUMER_GROUPS: tl.constexpr,
USE_TMA_LOAD_ON_SCALES: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
) -> None:
tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!")
tl.static_assert(
not (FUSE_SCATTER_ADD and USE_TMA_STORE),
"Cannot fuse scatter add with TMA store!",
)
tidx = tl.program_id(0)
dtype: tl.dtype = c_ptr.dtype.element_ty
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
M_end_offset = M_end_offset.to(tl.int64)
iterated_tiles = 0
iterated_tiles = iterated_tiles.to(tl.int64)
for g in tl.range(G):
# Move across groups
m_size = tl.load(m_sizes + g, cache_modifier=".ca")
if m_size > 0:
M_start_offset = M_end_offset
M_end_offset = M_start_offset + m_size
N_start_offset = g.to(tl.int64) * N
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
tl.static_assert(N % BLOCK_SIZE_N == 0)
NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
num_tiles = num_m_tiles * NUM_N_TILES
if USE_TMA_STORE:
with tl.async_task([0]):
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, N],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
next_iterated_tiles = iterated_tiles + num_tiles
if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
for i in range(tidx, next_iterated_tiles, NUM_SMS):
gidx = i - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros(
(BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
)
tl.static_assert(K % BLOCK_SIZE_K == 0)
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
with tl.async_task([0]):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
if USE_TMA_STORE:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
elif FUSE_SCATTER_ADD:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
mask = offs_am < m_size
m_offsets = tl.load(
scatter_add_indices + M_start_offset + offs_am,
mask=mask,
cache_modifier=".ca",
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.atomic_add(
c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
c,
mask=mask[:, None],
sem="relaxed",
)
else:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size,
cache_modifier=".cs",
)
tidx += NUM_SMS
iterated_tiles += num_tiles
TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv
# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
@triton.autotune(
configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={
"early_config_prune": functools.partial(
early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
)
},
restore_value=["c_ptr"], # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise(
a_desc_ptr,
a_scale_ptr,
b_desc_ptr,
b_scale_ptr,
b_scale_desc_ptr,
c_ptr,
workspace,
scatter_add_indices,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET,
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
FUSE_SCATTER_ADD: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
USE_FAST_ACCUM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
tl.static_assert(
not (FUSE_SCATTER_ADD and USE_TMA_STORE),
"Cannot fuse scatter add with TMA store!",
)
tidx = tl.program_id(0)
dtype = TT_FP8_DTYPE
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
M_end_offset = M_end_offset.to(tl.int64)
iterated_tiles = 0
iterated_tiles = iterated_tiles.to(tl.int64)
for g in tl.range(G):
# Move across groups
m_size = tl.load(m_sizes + g)
if m_size > 0:
M_start_offset = M_end_offset
M_end_offset = M_start_offset + m_size
N_start_offset = g.to(tl.int64) * N
n_size = N
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if USE_TMA_STORE:
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
gidx = tidx - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.static_assert(K % BLOCK_SIZE_K == 0)
if USE_TMA_LOAD:
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (
a_desc_ptr
+ (M_start_offset + offs_am[:, None]) * K
+ offs_k[None, :]
)
b_ptrs = (
b_desc_ptr
+ (N_start_offset + offs_bn[:, None]) * K
+ offs_k[None, :]
)
for k_offset in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
accumulator += tl.dot(a, b.T)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
mask=offs_am[:, None] < m_size,
)
b_scale = tl.load(
b_scale_ptr + N_start_offset + offs_bn[None, :],
mask=offs_bn[None, :] < n_size,
)
c = accumulator.to(tl.float32) * a_scale * b_scale
if USE_TMA_STORE:
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
c.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
elif FUSE_SCATTER_ADD:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = offs_am < m_size
m_offsets = tl.load(
scatter_add_indices + M_start_offset + offs_am,
mask=mask,
cache_modifier=".ca",
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
tl.atomic_add(
c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
c.to(c_ptr.dtype.element_ty),
mask=mask[:, None] and offs_bn[None, :] < n_size,
sem="relaxed",
)
else:
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
)
tidx += NUM_SMS
iterated_tiles += num_tiles
# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
configs=_NV_WS_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={
"early_config_prune": functools.partial(
early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
)
},
restore_value=["c_ptr"], # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise_ws(
a_desc_ptr,
a_scale_ptr,
b_desc_ptr,
b_scale_ptr,
b_scale_desc_ptr,
c_ptr,
workspace,
scatter_add_indices,
m_sizes,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
FUSE_SCATTER_ADD: tl.constexpr,
USE_TMA_LOAD: tl.constexpr,
USE_FAST_ACCUM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_CONSUMER_GROUPS: tl.constexpr,
USE_TMA_LOAD_ON_SCALES: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
) -> None:
tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
tl.static_assert(
not (FUSE_SCATTER_ADD and USE_TMA_STORE),
"Cannot fuse scatter add with TMA store!",
)
tidx = tl.program_id(0)
dtype = TT_FP8_DTYPE
TMA_SIZE: tl.constexpr = tl.constexpr(128)
if USE_TMA_STORE:
c_desc_ptr = workspace + tidx * TMA_SIZE
else:
c_desc_ptr = None
M_end_offset = 0
M_end_offset = M_end_offset.to(tl.int64)
iterated_tiles = 0
iterated_tiles = iterated_tiles.to(tl.int64)
for g in tl.range(G):
# Move across groups
m_size = tl.load(m_sizes + g, cache_modifier=".ca")
if m_size > 0:
M_start_offset = M_end_offset
M_end_offset = M_start_offset + m_size
N_start_offset = g.to(tl.int64) * N
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
tl.static_assert(N % BLOCK_SIZE_N == 0)
NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
num_tiles = num_m_tiles * NUM_N_TILES
if USE_TMA_STORE:
with tl.async_task([0]):
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, N],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
# Move across tiles
next_iterated_tiles = iterated_tiles + num_tiles
if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
for i in range(tidx, next_iterated_tiles, NUM_SMS):
gidx = i - iterated_tiles
# Split M first and N second.
tile_m_idx = gidx % num_m_tiles
tile_n_idx = gidx // num_m_tiles
accumulator = tl.zeros(
(BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
)
tl.static_assert(K % BLOCK_SIZE_K == 0)
m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
with tl.async_task([0]):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[m_offset, k_offset],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
dtype,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[n_offset, k_offset],
[BLOCK_SIZE_N, BLOCK_SIZE_K],
dtype,
)
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
if USE_TMA_LOAD_ON_SCALES:
with tl.async_task([0]):
b_scale = tl._experimental_descriptor_load(
b_scale_desc_ptr,
[n_offset],
[BLOCK_SIZE_N],
tl.float32,
)
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
mask=offs_am[:, None] < m_size,
cache_modifier=".ca",
)
c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
else:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
a_scale = tl.load(
a_scale_ptr + M_start_offset + offs_am[:, None],
mask=offs_am[:, None] < m_size,
cache_modifier=".ca",
)
b_scale = tl.load(
b_scale_ptr + N_start_offset + offs_bn[None, :],
cache_modifier=".ca",
)
c = accumulator.to(tl.float32) * a_scale * b_scale
if USE_TMA_STORE:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
c.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
elif FUSE_SCATTER_ADD:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
mask = offs_am < m_size
m_offsets = tl.load(
scatter_add_indices + M_start_offset + offs_am,
mask=mask,
cache_modifier=".ca",
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
tl.atomic_add(
c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
c,
mask=mask[:, None],
sem="relaxed",
)
else:
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size,
cache_modifier=".cs",
)
tidx += NUM_SMS
iterated_tiles += num_tiles
warnings.simplefilter("once")
def _grouped_gemm(
*,
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
x_scale: Optional[torch.Tensor],
w_scale: Optional[torch.Tensor],
use_fast_accum: bool,
use_warp_specialization: bool,
output_tensor: Optional[torch.Tensor],
scatter_add_indices: Optional[torch.Tensor],
) -> torch.Tensor:
USE_TMA_LOAD = not torch.version.hip
USE_TMA_STORE = False
if USE_TMA_LOAD and not HAS_TMA_DESC:
USE_TMA_LOAD = False
warnings.warn("TMA load is disabled as there is no TMA descriptor support!")
if USE_TMA_STORE and not HAS_TMA_DESC:
USE_TMA_STORE = False
warnings.warn("TMA store is disabled as there is no TMA descriptor support!")
# TODO(shikaili): Check the readiness of WS on the ROCm side in Meta's Triton.
if use_warp_specialization and torch.version.hip:
warnings.warn("Warp specialization is disabled as it is not supported on ROCm.")
use_warp_specialization = False
if use_warp_specialization and not _HAS_WS_SUPPORT:
warnings.warn(
"Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
)
use_warp_specialization = False
if use_warp_specialization:
assert HAS_TMA_DESC
USE_TMA_STORE = True # Tuning decision
G = m_sizes.shape[0]
assert x.is_contiguous()
assert w.is_contiguous()
assert m_sizes.is_contiguous()
M, K = x.shape
N = w.shape[0] // G
assert K == w.shape[1]
if output_tensor is None:
FUSE_SCATTER_ADD = False
assert scatter_add_indices is None
y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
else:
FUSE_SCATTER_ADD = True
assert scatter_add_indices is not None
assert scatter_add_indices.is_contiguous()
assert scatter_add_indices.shape == (M,)
y = output_tensor
if M == 0 or N == 0:
return y
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
desc_helper = None
desc_x = x
desc_w = w
desc_ws = w_scale
workspace = None
if USE_TMA_LOAD:
desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("x")
desc_helper.init_tma_descriptor("w")
desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
if use_warp_specialization and w_scale is not None:
desc_helper.init_tma_descriptor("ws")
desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws")
if USE_TMA_STORE:
workspace = torch.empty(
NUM_SMS * TmaAutoTuneHelper.TMA_SIZE,
device=x.device,
dtype=torch.uint8,
)
def grid(META):
if USE_TMA_LOAD:
nonlocal desc_helper # noqa: F824
desc_helper.fill_2d_tma_descriptor(
"x",
x.data_ptr(),
M,
K,
META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"],
META["BLOCK_SIZE_K"],
x.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"w",
w.data_ptr(),
N * G,
K,
META["BLOCK_SIZE_N"],
META["BLOCK_SIZE_K"],
w.element_size(),
)
if META.get("USE_TMA_LOAD_ON_SCALES", False):
desc_helper.fill_1d_tma_descriptor(
"ws",
w_scale.data_ptr(),
N * G,
META["BLOCK_SIZE_N"],
w_scale.element_size(),
)
return (NUM_SMS,)
M_BUCKET_CAP = 16384
M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
if x_scale is not None and w_scale is not None:
assert x_scale.is_contiguous()
assert w_scale.is_contiguous()
fn = (
_fbgemm_grouped_gemm_fp8_rowwise_ws
if use_warp_specialization
else _fbgemm_grouped_gemm_fp8_rowwise
)
args = (
desc_x,
x_scale,
desc_w,
w_scale,
desc_ws,
y,
workspace,
scatter_add_indices,
m_sizes,
G,
M_BUCKET,
N,
K,
NUM_SMS,
FUSE_SCATTER_ADD,
USE_TMA_LOAD,
)
if use_warp_specialization:
args += (use_fast_accum,)
else:
args += (USE_TMA_STORE, use_fast_accum)
fn[grid](*args)
else:
assert x_scale is None
assert w_scale is None
fn = (
_fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm
)
args = (
desc_x,
desc_w,
y,
workspace,
scatter_add_indices,
m_sizes,
G,
M_BUCKET,
N,
K,
NUM_SMS,
FUSE_SCATTER_ADD,
USE_TMA_LOAD,
)
if use_warp_specialization:
args += (use_fast_accum,)
else:
args += (USE_TMA_STORE, use_fast_accum)
fn[grid](*args)
return y
def grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_fast_accum: bool = True,
*,
_use_warp_specialization: bool = True,
_output_tensor: Optional[torch.Tensor] = None,
_scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
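    """BF16 grouped GEMM.

    For each group g this computes x[rows of group g] @ w[g * N:(g + 1) * N].T, where x is
    [M, K], w is [G * N, K] with the G expert weight blocks stacked along the first
    dimension, and m_sizes[g] is the number of consecutive rows of x owned by group g.
    Returns a bfloat16 tensor of shape [M, N], unless _output_tensor is supplied, in which
    case results are scatter-added into it at _scatter_add_indices.
    """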
return _grouped_gemm(
x=x,
w=w,
m_sizes=m_sizes,
x_scale=None,
w_scale=None,
use_fast_accum=use_fast_accum,
use_warp_specialization=_use_warp_specialization,
output_tensor=_output_tensor,
scatter_add_indices=_scatter_add_indices,
)
def grouped_gemm_fp8_rowwise(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
use_fast_accum: bool = True,
*,
_use_warp_specialization: bool = True,
_output_tensor: Optional[torch.Tensor] = None,
_scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
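    """FP8 grouped GEMM with rowwise scaling.

    Same layout as grouped_gemm, but x and w hold float8 values and each tile is rescaled
    as (x_g @ w_g.T) * x_scale[row, None] * w_scale[None, col] before being written out in
    bfloat16; x_scale is indexed per activation row and w_scale per row of w (i.e. per
    output feature within each group).
    """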
return _grouped_gemm(
x=x,
w=w,
m_sizes=m_sizes,
x_scale=x_scale,
w_scale=w_scale,
use_fast_accum=use_fast_accum,
use_warp_specialization=_use_warp_specialization,
output_tensor=_output_tensor,
scatter_add_indices=_scatter_add_indices,
)
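# Minimal usage sketch (hypothetical sizes; assumes a CUDA device with the Triton features
# probed above):
#
#     x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")        # [M, K]
#     w = torch.randn(8 * 14336, 4096, dtype=torch.bfloat16, device="cuda")  # [G * N, K]
#     m_sizes = torch.full((8,), 16, dtype=torch.int64, device="cuda")       # sums to M
#     y = grouped_gemm(x, w, m_sizes)                                        # [128, 14336]
#
# The pytest suite below cross-checks this kernel against SGLang's grouped_gemm_triton.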
import os
import sys
import pytest
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
from fbgemm_grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
FBGEMM_AVAILABLE = True
print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
FBGEMM_AVAILABLE = False
try:
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
SGLANG_AVAILABLE = True
print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
print(f"✗ Failed to import SGLang grouped GEMM: {e}")
SGLANG_AVAILABLE = False
def create_uniform_groups(batch_size, num_groups, device):
tokens_per_group = batch_size // num_groups
return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device)
def create_non_uniform_groups(batch_size, num_groups, device):
remaining = batch_size
m_sizes = []
for i in range(num_groups - 1):
if remaining <= 1:
size = 1
else:
max_size = remaining - (num_groups - i - 1) + 1
size = torch.randint(1, max_size, (1,)).item()
m_sizes.append(size)
remaining -= size
m_sizes.append(remaining)
return torch.tensor(m_sizes, dtype=torch.int64, device=device)
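# Randomly partitions batch_size rows into num_groups segments of at least one row each, so
# the non-uniform tests exercise uneven m_sizes rather than the even split used above.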
def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
batch_size = x.shape[0]
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device=device
)
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
current_pos = 0
for i, size in enumerate(m_sizes):
current_pos += size
seg_indptr[i + 1] = current_pos
weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
w_sglang = w.view(num_groups, intermediate_size, -1)
return c_sglang, seg_indptr, weight_indices, w_sglang
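# seg_indptr is the running sum of m_sizes (seg_indptr[i + 1] - seg_indptr[i] equals group
# i's row count) and w is viewed back as [num_groups, intermediate_size, hidden_size] to
# match SGLang's per-expert weight layout.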
def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
torch.manual_seed(42)
x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
w_fp16 = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
)
x_fp8 = x_fp16.to(torch.float8_e4m3fn)
w_fp8 = w_fp16.to(torch.float8_e4m3fn)
x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4
return x_fp8, w_fp8, x_scale, w_scale
@pytest.fixture
def device():
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
return torch.device("cuda")
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
if batch_size % num_groups != 0:
pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")
torch.manual_seed(42)
m_sizes = create_uniform_groups(batch_size, num_groups, device)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
)
result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
x, w, m_sizes, num_groups, intermediate_size, device
)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
batch_size, num_groups, hidden_size, intermediate_size, device
):
torch.manual_seed(42)
m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
)
result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
x, w, m_sizes, num_groups, intermediate_size, device
)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
batch_size, num_groups, hidden_size, intermediate_size, device
):
torch.manual_seed(42)
m_sizes = create_uniform_groups(batch_size, num_groups, device)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
)
result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
x, w, m_sizes, num_groups, intermediate_size, device
)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
batch_size, num_groups, hidden_size, intermediate_size, device
):
if batch_size % num_groups != 0:
pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")
torch.manual_seed(42)
m_sizes = create_uniform_groups(batch_size, num_groups, device)
x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
batch_size, num_groups, hidden_size, intermediate_size, device
)
try:
result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
assert result_fp8.shape == (batch_size, intermediate_size)
assert result_fp8.dtype == torch.bfloat16
except Exception as e:
pytest.skip(f"FP8 test failed (possibly unsupported): {e}")
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
batch_size, num_groups, hidden_size, intermediate_size, device
):
torch.manual_seed(42)
m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
batch_size, num_groups, hidden_size, intermediate_size, device
)
try:
result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
assert result_fp8.shape == (batch_size, intermediate_size)
assert result_fp8.dtype == torch.bfloat16
except Exception as e:
pytest.skip(f"FP8 test failed (possibly unsupported): {e}")
@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
torch.manual_seed(42)
batch_size, num_groups = 64, 4
hidden_size, intermediate_size = 512, 1024
m_sizes = create_uniform_groups(batch_size, num_groups, device)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
)
result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True)
assert result.shape == (batch_size, intermediate_size)
assert result.dtype == torch.bfloat16
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
torch.manual_seed(42)
batch_size, num_groups = 64, 4
hidden_size, intermediate_size = 512, 1024
m_sizes = create_uniform_groups(batch_size, num_groups, device)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
)
c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs(
x, w, m_sizes, num_groups, intermediate_size, device
)
result = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
assert result.shape == (batch_size, intermediate_size)
assert result.dtype == torch.bfloat16
def test_imports():
assert (
FBGEMM_AVAILABLE or SGLANG_AVAILABLE
), "Neither FBGEMM nor SGLang is available"
if __name__ == "__main__":
pytest.main([__file__, "-v"])