Commit 118f1fc7 by maxiao1

sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py
import os
import sys
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
def init_dist(local_rank: int, num_local_ranks: int, args):
ip = args.master_addr
port = args.master_port
num_nodes = args.nnodes
node_rank = args.node_rank
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group(
backend="nccl",
init_method=f"tcp://{ip}:{port}",
world_size=num_nodes * num_local_ranks,
rank=node_rank * num_local_ranks + local_rank,
)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.cuda.set_device(local_rank)
return (
dist.get_rank(),
dist.get_world_size(),
dist.new_group(list(range(num_local_ranks * num_nodes))),
)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
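# Relative difference via 2*sum(x*y) / (sum(x^2) + sum(y^2)); returns ~0 when x and y match.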
x, y = x.double() + 1, y.double() + 1
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return (1 - sim).item()
def per_token_cast_to_fp8(x: torch.Tensor):
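# Cast each 128-element chunk of every row to FP8 e4m3 with a per-chunk scale; 448 is the largest finite e4m3 value.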
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n
), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
def inplace_unique(x: torch.Tensor, num_slots: int):
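# Deduplicate each row's values in place (result sorted in descending order); unused slots are set to -1.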
assert x.dim() == 2
mask = x < 0
x_padded = x.masked_fill(mask, num_slots)
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
bin_count = bin_count[:, :num_slots]
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
x[:, :].fill_(-1)
valid_len = min(num_slots, x.size(1))
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
def create_grouped_scores(
scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
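# Zero out the scores of expert groups that are not selected in `group_idx`.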
num_tokens, num_experts = scores.shape
scores = scores.view(num_tokens, num_groups, -1)
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
return (scores * mask).view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
# Warmup
for _ in range(num_warmups):
fn()
# Flush L2
cache.zero_()
# Testing
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
for i in range(num_tests):
# Record
start_events[i].record()
fn()
end_events[i].record()
if post_fn is not None:
post_fn()
torch.cuda.synchronize()
times = np.array(
[s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
)[1:]
return np.average(times), np.min(times), np.max(times)
class empty_suppress:
def __enter__(self):
return self
def __exit__(self, *_):
pass
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def bench_kineto(
fn,
kernel_names,
num_tests: int = 30,
suppress_kineto_output: bool = False,
trace_path: Optional[str] = None,
barrier_comm_profiling: bool = False,
):
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
) as prof:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
for _ in range(num_tests):
fn()
prof.step()
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
prof_lines = (
prof.key_averages()
.table(sort_by="cuda_time_total", max_name_column_width=100)
.split("\n")
)
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert (
sum([name in line for line in prof_lines]) == 1
), f"Errors of the kernel {name} in the profiling table"
# Save chrome traces
if trace_path is not None:
prof.export_chrome_trace(trace_path)
# Return average kernel times
units = {"ms": 1e3, "us": 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, "")) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
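# Cheap checksum: reinterpret the raw tensor bits as int64 and sum them.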
return t.view(torch.int64).sum().item()
# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py
"""
Example usage:
python tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4
Then check `deepep_tuned.json`
"""
import argparse
import json
import time
from copy import deepcopy
from pathlib import Path
# noinspection PyUnresolvedReferences
import deep_ep
import torch
import torch.distributed as dist
from deepep_utils import (
bench,
calc_diff,
create_grouped_scores,
init_dist,
inplace_unique,
per_token_cast_back,
per_token_cast_to_fp8,
)
def test_main(
num_sms: int,
local_rank: int,
num_local_ranks: int,
num_ranks: int,
num_nodes: int,
rank: int,
buffer: deep_ep.Buffer,
group: dist.ProcessGroup,
args,
):
# Settings
num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
4096,
7168,
min(num_nodes, 4),
8,
(256 // num_ranks) * num_ranks,
)
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
flush=True,
)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
x_e4m3 = per_token_cast_to_fp8(x)
scores = (
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
+ 1
)
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(
group_scores, k=num_topk_groups, dim=-1, sorted=False
).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
1
]
topk_weights = (
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
)
topk_weights_pure_rand = torch.randn(
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
)
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
inplace_unique(rdma_rank_idx, num_nodes)
# RDMA dispatch counts
rdma_idx = topk_idx // (num_experts // num_nodes)
rdma_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rdma_idx, num_nodes)
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
token_idx_in_rank = torch.full(
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
)
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(
count, dtype=torch.long, device="cuda"
)
for i in range(num_nodes):
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
(
ref_num_tokens_per_rank,
ref_num_tokens_per_rdma_rank,
ref_num_tokens_per_expert,
ref_is_token_in_rank,
_,
) = buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
print("", flush=True)
group.barrier()
time.sleep(1)
# Config
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, recv_gbl_rank_prefix_sum):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = recv_gbl_rank_prefix_sum[i].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
flush=True,
end="",
)
dispatch_args = {
"x": current_x,
"num_tokens_per_rank": num_tokens_per_rank,
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"config": config,
"async_finish": async_mode,
}
if with_topk:
dispatch_args.update(
{
"topk_idx": topk_idx,
"topk_weights": (
topk_weights_pure_rand
if current_x is x_pure_rand
else topk_weights
),
}
)
if previous_mode:
dispatch_args.update({"previous_event": buffer.capture()})
(
recv_x,
recv_topk_idx,
recv_topk_weights,
recv_num_tokens_per_expert_list,
handle,
event,
) = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = (
per_token_cast_back(*recv_x)
if isinstance(recv_x, tuple)
else recv_x
)
# Checks
recv_gbl_rank_prefix_sum = handle[-4]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
0
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
assert (
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
== recv_num_tokens_per_expert_list
)
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
if with_topk:
# Check `topk_idx`
assert (
recv_topk_idx.eq(-1)
| (
(recv_topk_idx >= 0)
& (recv_topk_idx < (num_experts // num_ranks))
)
).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = (
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
recv_topk_weights
)[recv_topk_idx.eq(-1)]
)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must be run without top-k)
if not with_topk:
dispatch_args = {
"x": current_x,
"handle": handle,
"config": config,
"async_finish": async_mode,
}
if previous_mode:
dispatch_args.update({"previous_event": buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = (
per_token_cast_back(*recv_x)
if isinstance(recv_x, tuple)
else recv_x
)
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
combine_args = {
"x": recv_x,
"handle": handle,
"config": config,
"async_finish": async_mode,
}
if with_topk:
combine_args.update({"topk_weights": recv_topk_weights})
if previous_mode:
combine_args.update({"previous_event": buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(
**combine_args
)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(
dim=1
).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = (
combined_topk_weights
if (current_x is x_pure_rand)
else (
combined_topk_weights
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
)
)
ref_topk_weights = (
topk_weights_pure_rand
if current_x is x_pure_rand
else topk_weights
)
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
if local_rank == 0:
print(" passed", flush=True)
if local_rank == 0:
print("", flush=True)
output_data = {}
# Tune dispatch performance
best_dispatch_results = None
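# FP8 payload size relative to BF16: 1 byte per element plus 4 bytes of float32 scale per 128 elements, vs. 2 bytes.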
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
rdma_send_bytes = (
(dispatch_bf16_rdma_send_bytes * fp8_factor)
if isinstance(current_x, tuple)
else dispatch_bf16_rdma_send_bytes
)
nvl_recv_bytes = (
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
if isinstance(current_x, tuple)
else dispatch_bf16_nvl_recv_bytes
)
for nvl_chunk_size in range(4, 33, 4):
for rdma_chunk_size in range(4, 33, 4):
config_kwargs = {
"num_sms": num_sms,
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
}
config = deep_ep.Config(**config_kwargs)
tune_args = {"x": current_x, "handle": handle, "config": config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (
num_sms,
nvl_chunk_size,
rdma_chunk_size,
config_kwargs,
)
if local_rank == 0:
print(
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
flush=True,
)
if local_rank == 0:
print(
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
flush=True,
)
print("", flush=True)
is_fp8 = isinstance(current_x, tuple)
if is_fp8:
output_data["normal_dispatch"] = deepcopy(best_results[3])
if isinstance(current_x, tuple):
# Gather the best FP8 config from rank 0
best_dispatch_results = torch.tensor(
[best_results[0], best_results[1], best_results[2]],
dtype=torch.int32,
device="cuda",
)
all_best_fp8_results_list = [
torch.zeros_like(best_dispatch_results)
for _ in range(torch.distributed.get_world_size())
]
dist.all_gather(
all_best_fp8_results_list, best_dispatch_results, group=group
)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(
best_dispatch_results[0],
best_dispatch_results[1],
nvl_buffer_size,
best_dispatch_results[2],
rdma_buffer_size,
)
dispatch_args = {
"x": x,
"num_tokens_per_rank": num_tokens_per_rank,
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"config": dispatch_config if dispatch_config is not None else config,
}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
for rdma_chunk_size in range(8, 33, 4):
config_kwargs = {
"num_sms": num_sms,
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
}
config = deep_ep.Config(**config_kwargs)
tune_args = {"x": recv_x, "handle": handle, "config": config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
flush=True,
)
if t < best_time:
best_time, best_results = t, (
num_sms,
nvl_chunk_size,
rdma_chunk_size,
config_kwargs,
)
if local_rank == 0:
print(
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
flush=True,
)
print("", flush=True)
output_data["normal_combine"] = deepcopy(best_results[3])
if rank == 0 and local_rank == 0:
_write_output(args, output_data)
def _write_output(args, output_data):
text = json.dumps(output_data, indent=4)
output_path = args.output_path
print(f"Write to {output_path} with {text}")
Path(output_path).write_text(text)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args):
num_nodes = args.nnodes
rank, num_ranks, group = init_dist(local_rank, num_local_ranks, args)
num_sms = args.num_sms
num_qps_per_rank = num_sms // 2
buffer = deep_ep.Buffer(
group,
int(1e9),
int(1e9),
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (num_sms,):
test_main(
i,
local_rank,
num_local_ranks,
num_ranks,
num_nodes,
rank,
buffer,
group,
args,
)
if local_rank == 0:
print("", flush=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-sms", type=int, default=24)
parser.add_argument("--output-path", type=str, default="deepep_tuned.json")
parser.add_argument("--nnodes", type=int, default=1)
parser.add_argument("--node-rank", type=int, default=0)
parser.add_argument("--master-addr", type=str, default="127.0.0.1")
parser.add_argument("--master-port", type=int, default=8361)
args = parser.parse_args()
print(f"Start system with {args=}")
num_processes = 8
torch.multiprocessing.spawn(
test_loop, args=(num_processes, args), nprocs=num_processes
)
## DeepSeek kernels benchmark
### Prerequisites
- You should install [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) from source before running `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.
### Benchmark
- `benchmark_deepgemm_fp8_gemm.py`
```bash
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1
```
- `benchmark_deepgemm_fp8_group_gemm.py`
```bash
python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1
```
- You can use the `--run_correctness` parameter to verify the correctness of all kernel results.
- You can use the `--tp_size` parameter to benchmark all FP8 w8a8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the given tensor-parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) implementation with the SGLang and vLLM Triton implementations; a sketch of how `--tp_size` shards the weight shapes follows below.
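As a quick reference, here is a minimal sketch (not part of the benchmark scripts themselves) of how `--tp_size` is applied, mirroring the `get_weight_shapes` helper used below: N-shardable weights divide their output dimension by the TP size, K-shardable weights divide their input dimension, and non-shardable weights are kept whole. The shape values below are a small illustrative subset.

```python
# Hypothetical helper mirroring get_weight_shapes: shard DeepSeek V3/R1 weight
# shapes for a given tensor-parallel size. Shape lists here are illustrative.
def shard_weight_shapes(tp_size: int):
    no_tp = [(512 + 64, 7168)]                  # cannot be sharded
    n_tp = [(18432 * 2, 7168), (24576, 1536)]   # shard the output (N) dimension
    k_tp = [(7168, 18432), (7168, 2048)]        # shard the input (K) dimension
    shapes = list(no_tp)
    shapes += [(n // tp_size, k) for n, k in n_tp]
    shapes += [(n, k // tp_size) for n, k in k_tp]
    return shapes


print(shard_weight_shapes(4))  # TP=4 turns (36864, 7168) into (9216, 7168)
```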
from typing import Tuple
import deep_gemm
import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
)
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
def tl_gemm(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"e4m3_float8",
], "Currently only e4m3_float8 is supported"
assert out_dtype in [
"bfloat16",
"float16",
], "Currently only bfloat16 and float16 are supported"
TILE_SIZE = (128, 128, 128)
block_M = TILE_SIZE[0]
block_N = TILE_SIZE[1]
block_K = TILE_SIZE[2]
A_shape = (M, K)
Scales_A_shape = (M, T.ceildiv(K, block_K))
B_shape = (N, K)
Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (block_M, block_N)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
scales_a: T.Buffer(Scales_A_shape, "float32"),
B: T.Buffer(B_shape, in_dtype),
scales_b: T.Buffer(Scales_B_shape, "float32"),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), "float32")
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n
), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2)
)
def fp8_gemm_deepgemm(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""DeepGEMM implementation of FP8 GEMM"""
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
def fp8_gemm_sglang(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""SGLang implementation of FP8 GEMM"""
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
# Run SGLang kernel
out = w8a8_block_fp8_matmul(
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
)
return out
def fp8_gemm_vllm(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""vLLM implementation of FP8 GEMM"""
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
# Run vLLM kernel
out = vllm_w8a8_block_fp8_matmul(
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
)
return out
def calculate_diff(m: int, n: int, k: int):
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
out_deepgemm = fp8_gemm_deepgemm(
x_fp8.clone(),
x_scale_col_major.clone(),
y_fp8.clone(),
y_scale.clone(),
m,
n,
k,
)
out_sglang = fp8_gemm_sglang(
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
)
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
out_tilelang = tilelang_kernel(
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
)
diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()
print(f"Shape m={m}, n={n}, k={k}:")
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
print(f"SGLang output: {out_sglang[0, 0:5]}")
print(f"TileLang output: {out_tilelang[0, 0:5]}")
print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")
sglang_deepgemm_match = torch.allclose(
out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
)
tilelang_deepgemm_match = torch.allclose(
out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
)
tilelang_sglang_match = torch.allclose(
out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
)
if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
print("✅ All implementations match\n")
else:
print("❌ Some implementations differ:")
print(f" - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
print(f" - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
print(f" - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
def get_weight_shapes(tp_size):
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
def create_benchmark_configs(tp_size):
configs = []
weight_shapes = get_weight_shapes(tp_size)
batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
for n, k in weight_shapes:
for m in batch_sizes:
configs.append((m, n, k, tp_size))
return configs
def get_benchmark(tp_size):
all_configs = create_benchmark_configs(tp_size)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["m", "n", "k", "tp_size"],
x_vals=[list(config) for config in all_configs],
line_arg="provider",
line_vals=["deepgemm", "sglang", "tilelang"],
line_names=["DeepGEMM", "SGLang", "TileLang"],
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
args={},
)
)
def benchmark(m, n, k, tp_size, provider):
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Preprocess data before benchmarking
x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
quantiles = [0.5, 0.2, 0.8]
if provider == "deepgemm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_deepgemm(
x_fp8.clone(),
x_scale_col_major.clone(),
y_fp8.clone(),
y_scale.clone(),
m,
n,
k,
),
quantiles=quantiles,
)
elif provider == "sglang":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_sglang(
x_fp8.clone(),
x_scale.clone(),
y_fp8.clone(),
y_scale.clone(),
m,
n,
k,
),
quantiles=quantiles,
)
else: # tilelang
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: tilelang_kernel(
x_fp8.clone(),
x_scale.clone(),
y_fp8.clone(),
y_scale.clone(),
),
quantiles=quantiles,
)
# Calculate TFLOPS
flops = 2 * m * n * k # multiply-adds
tflops = flops / (ms * 1e-3) / 1e12
# Print shape-specific results with TFLOPS
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/fp8_gemm/",
help="Path to save fp8 gemm benchmark results",
)
parser.add_argument(
"--run_correctness",
action="store_true",
default=True,
help="Whether to run correctness test",
)
parser.add_argument(
"--tp_size",
type=int,
default=1,
help="Tensor parallelism size to benchmark (default: 1)",
)
args = parser.parse_args()
# Set random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Run correctness tests on a few examples
if args.run_correctness:
print("Running correctness tests...")
calculate_diff(64, 512, 7168) # Small test
calculate_diff(64, 7168, 16384) # Medium test
calculate_diff(64, 18432, 7168) # Large test
# Get the benchmark function with the specified tp_size
benchmark = get_benchmark(args.tp_size)
print(f"Running performance benchmark for TP size = {args.tp_size}...")
benchmark.run(print_data=True, save_path=args.save_path)
from typing import Tuple
import deep_gemm
import torch
import triton
import triton.language as tl
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
# Import shared functionality from the regular GEMM benchmark
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
per_block_cast_to_fp8,
per_token_cast_to_fp8,
)
def construct_grouped_and_flat_fp8(
x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
) -> Tuple[
Tuple[torch.Tensor, torch.Tensor], # grouped x_fp8
Tuple[torch.Tensor, torch.Tensor], # grouped y_fp8
Tuple[torch.Tensor, torch.Tensor], # flat x_fp8
Tuple[torch.Tensor, torch.Tensor], # flat y_fp8
torch.Tensor, # output
torch.Tensor, # reference output
]:
# Verify input shapes
m, k = x.shape
n, k_y = y.shape
assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
assert m % 4 == 0, f"TMA alignment error: {m}"
# Reshape inputs for grouped processing
m_per_group = m // num_groups
x_grouped = x.view(num_groups, m_per_group, k)
y_grouped = y.unsqueeze(0).expand(num_groups, n, k)
# Initialize output tensors
out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)
# Quantize grouped tensors
x_fp8_grouped = (
torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
),
)
y_fp8_grouped = (
torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
),
)
for i in range(num_groups):
x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])
# Quantize flat tensors
x_fp8_flat = per_token_cast_to_fp8(x)
y_fp8_flat = per_block_cast_to_fp8(y)
# For non-masked input, merge the group and M dims in output
if not is_masked:
x_fp8_grouped = (
x_fp8_grouped[0].view(-1, k),
per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
)
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
# Transpose earlier for testing
x_fp8_grouped = (
x_fp8_grouped[0],
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
)
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
# custom kernel based on the Triton tutorial.
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
@triton.jit
def fp8_gemm_group_triton_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Pointers to scaling factors
a_scale_ptr,
b_scale_ptr,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension.
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Strides for scaling factors
stride_a_scale_m,
stride_a_scale_k,
stride_b_scale_n,
stride_b_scale_k,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
Note: Block sizes must be multiples of 32 for optimal TMA performance.
"""
# Map program ids to the block of C it should compute
pid_group = tl.program_id(axis=0) # Group ID
pid_n = tl.program_id(axis=1) # N dimension ID
# Compute the M block ID within this group
group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
pid_m_within_group = tl.program_id(axis=2) % group_size_m
pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
# Create pointers for the first blocks of A and B
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# Initialize accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Main loop
for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_offset = k_block * BLOCK_SIZE_K
# Load the next block of A and B, with masks
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)
# Calculate indices for scaling factors for this K block
a_scale_ptrs = a_scale_ptr + (
offs_am * stride_a_scale_m + k_block * stride_a_scale_k
)
b_scale_ptrs = b_scale_ptr + (
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
)
# Perform matrix multiplication in FP8
res = tl.dot(a, b)
# Load scaling factors for the current block
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
b_scale = tl.load(b_scale_ptrs)
# Apply scaling factors to the accumulated result
accumulator += res * a_scale * b_scale
# Advance pointers
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# Convert to bfloat16 for output
c = accumulator.to(tl.bfloat16)
# Write back the result
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
"""
Perform matrix multiplication with FP8 inputs and proper scaling.
Args:
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
c: Output tensor in BF16 format
num_groups: Number of groups for grouped GEMM
Returns:
Result tensor in BF16 format
"""
# Unpack the tuples
a, a_scale = a_tuple
b, b_scale = b_tuple
M, K = a.shape
_, N = b.shape
# Configure block sizes - must be multiples of 32 for TMA alignment
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 128
# Calculate grid dimensions
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
num_groups_grid = triton.cdiv(num_pid_m, num_groups)
# 3D grid launch - (group, n_blocks, m_blocks_per_group)
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
fp8_gemm_group_triton_kernel[grid](
a,
b,
c,
a_scale,
b_scale,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
a_scale.stride(0),
1, # Stride in the K dimension may be 1
b_scale.stride(0),
1 if b_scale.dim() > 1 else 0,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=num_groups,
)
return c
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
x_fp8_grouped,
y_fp8_grouped,
out,
m_indices,
)
return out
def calculate_diff(m: int, n: int, k: int, num_groups: int):
print(f"Shape (m={m}, n={n}, k={k}")
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
)
m_per_group = m // num_groups
out_deepgemm = out.clone()
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
m_indices = (
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
)
fp8_gemm_group_deepgemm(
x_fp8_grouped,
y_fp8_grouped,
out_deepgemm,
m_indices,
)
torch.cuda.synchronize()
# Prepare inputs for Triton
a, a_scale = x_fp8_flat
b, b_scale = y_fp8_flat
b = b.T.contiguous()
# Ensure scales are in the right format and contiguous
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
M, _ = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
torch.cuda.synchronize()
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
print(f"Shape m={m}, n={n}, k={k}:")
print(f"Torch output: {out_torch[0, 0:5]}")
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
print(f"Triton output: {out_triton[0, 0:5]}")
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
triton_torch_diff = calc_diff(out_triton, out_torch)
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
DIFF_THRESHOLD = 0.001
all_match = (
deepgemm_torch_diff < DIFF_THRESHOLD
and triton_torch_diff < DIFF_THRESHOLD
and deepgemm_triton_diff < DIFF_THRESHOLD
)
if all_match:
print("✅ All implementations match\n")
else:
print("❌ Some implementations differ:")
print(
f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}"
f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}"
f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
)
def get_weight_shapes(tp_size):
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
def create_benchmark_configs(tp_size):
configs = []
weight_shapes = get_weight_shapes(tp_size)
batch_sizes = [2048, 4096]
group_sizes = [4, 8]
for n, k in weight_shapes:
for m in batch_sizes:
for num_groups in group_sizes:
configs.append((m, n, k, num_groups, tp_size))
return configs
def get_benchmark(tp_size):
all_configs = create_benchmark_configs(tp_size)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["m", "n", "k", "num_groups", "tp_size"],
x_vals=[config for config in all_configs],
line_arg="provider",
line_vals=["deepgemm", "triton"],
line_names=["DeepGEMM", "Triton"],
styles=[("blue", "-"), ("red", "-")],
ylabel="ms",
plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
args={},
)
)
def benchmark(m, n, k, num_groups, tp_size, provider):
print(
f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}"
)
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
)
m_per_group = m // num_groups
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
m_indices = (
m_indices.unsqueeze(-1)
.expand(num_groups, m_per_group)
.contiguous()
.view(-1)
)
quantiles = [0.5, 0.2, 0.8]
if provider == "deepgemm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_group_deepgemm(
x_fp8_grouped,
y_fp8_grouped,
out,
m_indices,
),
quantiles=quantiles,
)
elif provider == "triton":
# Prepare inputs for Triton
# Prepare inputs outside of the lambda so the comparison with DeepGEMM stays fair
a, a_scale = x_fp8_flat
b, b_scale = y_fp8_flat
b = b.T.contiguous()
# Ensure scales are in the right format and contiguous
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
M, _ = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_group_triton(
(a, a_scale),
(b, b_scale),
c,
num_groups,
),
quantiles=quantiles,
)
# Calculate TFLOPS
flops = 2 * m * n * k # multiply-adds
tflops = flops / (ms * 1e-3) / 1e12
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/fp8_group_gemm/",
help="Path to save deepgemm fp8 group gemm benchmark results",
)
parser.add_argument(
"--run_correctness",
action="store_true",
help="Whether to run correctness test",
)
parser.add_argument(
"--tp_size",
type=int,
default=1,
help="Tensor parallelism size to benchmark (default: 1)",
)
args = parser.parse_args()
# Set random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Run correctness tests on a few examples
if args.run_correctness:
print("Running correctness tests...")
calculate_diff(8192, 7168, 4096, 4)
calculate_diff(8192, 2048, 7168, 4)
calculate_diff(4096, 7168, 4096, 8)
calculate_diff(4096, 2048, 7168, 8)
calculate_diff(4096, 576, 7168, 8)
# Get the benchmark function with the specified tp_size
benchmark = get_benchmark(args.tp_size)
print(f"Running performance benchmark for TP size = {args.tp_size}...")
benchmark.run(print_data=True, save_path=args.save_path)
## Benchmark FBGEMM Grouped GEMM
Benchmarks FBGEMM Grouped GEMM (both the Triton and CUDA versions) against the SGLang Triton Grouped GEMM; the results are used to compare the bandwidth of the different implementations.
### Requirements
```shell
pip install fbgemm-gpu-genai
```
### Usage
```bash
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
```
For example, on an H200, the Qwen2-57B-A14B-Instruct TP4 FP8 W8A8 grouped GEMM bandwidth results are as follows:
```shell
grouped-gemm-performance:
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
0 256.0 3704.841339 3042.626402 2254.725030
1 512.0 3691.426346 3029.065684 2269.504543
2 1024.0 3653.938629 2258.471467 2358.319020
3 2048.0 3596.644313 2271.611904 2476.895397
4 4096.0 3468.496435 2231.283986 2179.473910
```
The theoretical peak bandwidth of the H200 is 4.8 TB/s. Taking batch_size 256 as an example, FBGEMM Triton Grouped GEMM FP8 reaches 3704.84 GB/s (about 77.2% of peak), FBGEMM CUTLASS F8F8BF16 Rowwise reaches 3042.63 GB/s (about 63.4% of peak), and SGLang Grouped GEMM FP8 reaches 2254.73 GB/s (about 47.0% of peak).
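As a quick check, these utilization figures can be recomputed from the table above, assuming the 4.8 TB/s (4800 GB/s) peak:

```python
# Recompute the batch_size=256 utilization numbers (assumes a 4.8 TB/s peak).
peak_gbps = 4800.0
measured_gbps = {
    "FBGEMM Triton Grouped GEMM FP8": 3704.841339,
    "FBGEMM CUTLASS F8F8BF16 Rowwise": 3042.626402,
    "SGLang Grouped GEMM FP8": 2254.725030,
}
for name, gbps in measured_gbps.items():
    print(f"{name}: {gbps / peak_gbps:.1%} of peak")
# -> 77.2%, 63.4%, 47.0%
```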
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse
import torch
import triton
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
quantize_fp8_row,
triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
def get_model_config(model_name: str, tp_size: int):
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
num_groups = config.ffn_config.moe_num_experts
intermediate_size = config.ffn_config.ffn_hidden_size
elif config.architectures[0] == "JambaForCausalLM":
num_groups = config.num_experts
intermediate_size = config.intermediate_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
]:
num_groups = config.n_routed_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
num_groups = config.text_config.num_local_experts
intermediate_size = config.text_config.intermediate_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
num_groups = config.num_local_experts
intermediate_size = config.moe_intermediate_size
else:
num_groups = config.num_local_experts
intermediate_size = config.intermediate_size
shape_configs = {
"num_groups": num_groups,
"hidden_size": config.hidden_size,
"intermediate_size": intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
base_weights = torch.randn(
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
w_sglang = base_weights
c_fbgemm = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
for i in range(1, num_groups + 1):
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
return (
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
)
def create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
"""
Create test data for FP8 grouped GEMM operations.
Args:
batch_size: Total batch size
num_groups: Number of groups
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
Returns:
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
For cutlass: (x, wq, w_scale, m_sizes)
"""
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
# Create weight matrices for each group
w_list = []
for _ in range(num_groups):
w = torch.randn(
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
)
w_list.append(w)
# Quantize weights using quantize_fp8_row for each group
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
if backend == "triton":
# Triton format: concatenated weights
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
# Create m_sizes as int32 for triton
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
# Create and quantize input
x_fp16 = torch.randn(
batch_size, hidden_size, dtype=torch.float16, device="cuda"
)
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
x_scale = x_scale.view(batch_size, -1)
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
elif backend == "cutlass":
# CUTLASS format: stacked weights
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Create m_sizes as int64 for cutlass
m_values = [tokens_per_group] * num_groups
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
# Create input data - separate for each group then concat
x_list = []
for _ in range(num_groups):
x = torch.randn(
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
)
x_list.append(x)
# Concatenate inputs into single tensor
x = torch.concat(x_list, dim=0).contiguous()
return x, wq, w_scale, m_sizes
else:
raise ValueError(f"Unsupported backend: {backend}")
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
"""
Calculate memory bandwidth based on accessed expert weights.
Args:
m_sizes: Tensor containing batch sizes for each group
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
dtype: Data type of weights
Returns:
Memory size in bytes for accessed expert weights
"""
# Count non-zero groups (active experts)
if hasattr(m_sizes, "cpu"):
active_experts = torch.count_nonzero(m_sizes).item()
else:
active_experts = sum(1 for m in m_sizes if m > 0)
# Calculate bytes per element based on dtype
if dtype in [torch.float16, torch.bfloat16]:
bytes_per_element = 2
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
bytes_per_element = 1
elif dtype == torch.float32:
bytes_per_element = 4
else:
# Default to 2 bytes for unknown dtypes
bytes_per_element = 2
# Memory per expert weight matrix
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
# Total memory for active experts
total_memory_bytes = active_experts * memory_per_expert
return total_memory_bytes
def get_benchmark_config(use_fp8_w8a8=False):
if use_fp8_w8a8:
return {
"line_vals": [
"fbgemm_triton_grouped_gemm_fp8",
"fbgemm_cutlass_f8f8bf16_rowwise",
"sglang_grouped_gemm",
],
"line_names": [
"FBGEMM Triton Grouped GEMM FP8",
"FBGEMM CUTLASS F8F8BF16 Rowwise",
"SGLang Grouped GEMM FP8",
],
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
}
else:
return {
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
"line_names": [
"FBGEMM Triton Grouped GEMM BF16",
"SGLang Grouped GEMM BF16",
],
"styles": [("blue", "-"), ("green", "-")],
}
def run_benchmark(
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
config = get_benchmark_config(use_fp8_w8a8)
benchmark_config = triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[256, 512, 1024, 2048, 4096],
line_arg="provider",
line_vals=config["line_vals"],
line_names=config["line_names"],
styles=config["styles"],
ylabel="Bandwidth (GB/s)",
plot_name="grouped-gemm-performance",
args={},
)
@triton.testing.perf_report(benchmark_config)
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"Benchmarking {provider} with batch_size={batch_size}")
torch.cuda.manual_seed_all(0)
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
if provider == "fbgemm_triton_grouped_gemm_fp8":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="triton",
)
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
def run_func():
return fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
except Exception as e:
print(f"FP8 not supported, skipping: {e}")
return float("inf"), float("inf"), float("inf")
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="cutlass",
)
x, wq, w_scale, m_sizes = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
# Quantize input using triton_quantize_fp8_row
xq, x_scale = triton_quantize_fp8_row(x)
x_scale = x_scale.view(batch_size, -1)
def run_func():
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
)
except Exception as e:
print(
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
f"skipping: {e}"
)
return float("inf"), float("inf"), float("inf")
else:
test_data = create_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
# Calculate memory bandwidth for BF16 operations
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.bfloat16
)
if provider == "fbgemm_triton_grouped_gemm":
def run_func():
return fbgemm_grouped_gemm(
x, w_fbgemm, m_sizes, use_fast_accum=True
)
else:
def run_func():
return sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
for _ in range(10):
try:
run_func()
except Exception as e:
print(f"Error during warmup for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
torch.cuda.synchronize()
try:
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
# Convert time (ms) to bandwidth (GB/s)
# Bandwidth = Memory (bytes) / Time (seconds)
# Convert ms to seconds and bytes to GB (1e9)
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
# min bandwidth = max time, max bandwidth = min time
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
return gb_per_s, min_gb_per_s, max_gb_per_s
except Exception as e:
print(f"Error during benchmarking for {provider}: {e}")
return 0.0, 0.0, 0.0
dynamic_benchmark.run(
show_plots=True,
print_data=True,
save_path=save_path,
model_config=model_config,
use_fp8_w8a8=use_fp8_w8a8,
)
def verify_correctness(model_config):
print("Verifying correctness...")
batch_size = 128
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
print("✓ BF16 Correctness verification passed!")
else:
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
return False
return True
def main():
parser = argparse.ArgumentParser(
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
help="Model name to get configuration from",
)
parser.add_argument(
"--tp-size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument(
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
)
parser.add_argument(
"--save-path",
type=str,
default="./benchmark_grouped_gemm/",
help="Path to save benchmark results",
)
parser.add_argument(
"--verify-correctness",
action="store_true",
help="Verify correctness before benchmarking",
)
args = parser.parse_args()
try:
model_config = get_model_config(args.model, args.tp_size)
except Exception as e:
print(f"Failed to get model config: {e}")
print("Using default configuration...")
model_config = {
"num_groups": 8,
"hidden_size": 4096,
"intermediate_size": 14336,
"dtype": torch.bfloat16,
}
print("Running benchmark with:")
print(f" num_groups: {model_config['num_groups']}")
print(f" hidden_size: {model_config['hidden_size']}")
print(f" intermediate_size: {model_config['intermediate_size']}")
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
if args.verify_correctness:
if not verify_correctness(model_config):
print("Correctness verification failed. Exiting...")
return
try:
run_benchmark(
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
save_path=args.save_path,
)
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()
# FlashInfer Fused AllReduce + RMSNorm Benchmark
This benchmark script is modified from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It compares SGLang's FlashInfer fused operator (`trtllm_allreduce_fusion`: AllReduce + Residual Add + RMSNorm + optional quantization) against the conventional implementation (standard `tensor_model_parallel_all_reduce` followed by separate RMSNorm/quantization ops). Concretely, the script times two paths: 1) standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator combining AllReduce, Residual Add, RMSNorm, and optional quantization.
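As a point of reference for what the fused kernel computes after the AllReduce, here is a minimal plain-PyTorch sketch of the residual add + RMSNorm step (illustrative math only, not the fused CUDA implementation):
```
# Illustrative sketch: residual add followed by RMSNorm
# (y = h / sqrt(mean(h^2) + eps) * gamma, with h = x + residual).
import torch

def residual_add_rmsnorm(x, residual, gamma, eps: float = 1e-6):
    hidden = x + residual                                   # residual add
    variance = hidden.float().pow(2).mean(-1, keepdim=True)
    normed = hidden.float() * torch.rsqrt(variance + eps)   # RMS normalization
    return (normed * gamma.float()).to(x.dtype), hidden     # (norm_out, residual_out)

x = torch.randn(4, 1024, dtype=torch.bfloat16)
residual = torch.randn_like(x)
gamma = torch.ones(1024, dtype=torch.bfloat16)
norm_out, residual_out = residual_add_rmsnorm(x, residual, gamma)
```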
This benchmark also helps us tune the IPC workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepares for applying the FP8/FP4 quantized fused operators.
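As a rough guide, the maximum token count a single fused call can handle follows from the fixed per-world-size IPC workspace budget. The sketch below mirrors the workspace sizing in the script; the factor of 2 is assumed to correspond to 16-bit elements:
```
# Rough sketch: IPC workspace budget -> max tokens per fused call.
MiB = 1024 * 1024
FI_MAX_SIZES = {2: 64 * MiB, 4: 64 * MiB, 8: 64 * MiB}  # bytes per world size

def max_token_num(world_size: int, hidden_dim: int) -> int:
    # the "* 2" is assumed to account for 16-bit elements
    return FI_MAX_SIZES[world_size] // (hidden_dim * world_size * 2)

print(max_token_num(2, 1024))  # 16384
print(max_token_num(8, 8192))  # 512
```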
Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`
## Feature Overview
- Compare average execution time (ms) and calculate speedup ratios for the following paths:
- standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)
- flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes
- Optionally compare FP8/FP4 quantized fused paths with standard paths
- Use CUDA Graph capture and batch replay to reduce measurement noise
- Automatically selects the faster "standard baseline" (the native or compiled variant) as the denominator for speedup calculation (see the sketch after this list)
- Optionally export results in Markdown format
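A minimal sketch of that baseline selection and speedup calculation, with made-up illustrative timings (the operation names match what the script reports):
```
# The faster of the two standard variants becomes the baseline (denominator).
times_ms = {
    "standard_allreduce_rmsnorm": 0.030,
    "standard_allreduce_rmsnorm_native_compiled": 0.024,
    "flashinfer_fused_allreduce_rmsnorm_oneshot": 0.012,
    "flashinfer_fused_allreduce_rmsnorm_twoshot": 0.048,
}
baseline = min(
    ("standard_allreduce_rmsnorm", "standard_allreduce_rmsnorm_native_compiled"),
    key=lambda name: times_ms.get(name, float("inf")),
)
for name, t in times_ms.items():
    label = "baseline" if name == baseline else f"{times_ms[baseline] / t:.2f}x"
    print(f"{name:<50} {t:<10.3f} {label}")
```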
## Runtime Environment and Prerequisites
- At least 2 GPUs; launch multiple benchmark processes with `torchrun` (NCCL backend)
- Properly install/compile sglang along with sgl-kernel and custom operators
## Quick Start (Command Examples)
The following examples use world_size=2. Adjust `--nproc_per_node` and the other parameters for your machine:
- Regular paths only (no quantization):
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP8 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP4 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- Larger hidden dimensions:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
```
## Parameter Description
- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)
- `--hidden-dim`: Hidden dimension (default: 8192)
- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16)
- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual")
- Mutually exclusive quantization options:
- `--no-quant`: No quantization testing
- `--quant-fp8`: Only FP8 quantization testing
- `--quant-fp4`: Only FP4 quantization testing
- `--quant-all`: Test all (default)
- FlashInfer related:
- `--disable-oneshot`: Disable oneshot mode (by default both oneshot and twoshot are benchmarked)
- Runtime configuration:
- `--warmup`: Warmup count before graph capture and before graph replay (default 5)
- `--trials`: Benchmark iteration count (default 20; internally, each `graph.replay()` replays a batch of captured operations)
- `--output-file`: Save results as Markdown file (only rank0 takes effect)
## Output Example
Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example:
```
================================================================================
Results: seq_len=1024, hidden_dim=1024
dtype=torch.bfloat16, residual=yes, quant_mode=none
================================================================================
Operation Time (ms) Speedup
--------------------------------------------------------------------------------
standard_allreduce_rmsnorm 0.024 0.98x
standard_allreduce_rmsnorm_native_compiled 0.023 baseline
flashinfer_fused_allreduce_rmsnorm_oneshot 0.011 2.19x
flashinfer_fused_allreduce_rmsnorm_twoshot 0.041 0.57x
```
If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file.
## Important Notes and Recommendations
- Distributed: The script uses `torchrun` environment variables to initialize the distributed environment and binds tensors and communication groups to the device of the current rank.
- World size: Requires `WORLD_SIZE > 1` to benchmark the communication operators; otherwise the script raises an error with a hint.
- FlashInfer:
  - If FlashInfer is not installed or the required interface is missing, the script only runs the standard paths and logs a warning.
  - The fused operator supports two trigger modes, "oneshot" and "twoshot"; oneshot is enabled by default and twoshot is always benchmarked as well.
- FP8/FP4:
  - FP8 uses sglang's FP8 utilities and dtype; the underlying platform selects `e4m3` or `e4m3fnuz` as appropriate.
  - FP4 uses sgl-kernel's `scaled_fp4_quant`, which requires platform support.
- CUDA Graph:
  - Uses sglang's `graph_capture()` to put the communication ops into a capture-safe state, then captures the kernels with `torch.cuda.graph` to reduce measurement jitter (see the sketch below).
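A minimal sketch of this measurement scheme, assuming a CUDA device; it mirrors `benchmark_operation` in the script below but omits the sglang `graph_capture()` wrapper that makes the collective capture-safe:
```
# Capture several launches of the op in one CUDA graph, then time whole-graph
# replays and report the average time per launch.
import time
import torch

def time_op_with_cuda_graph(op, warmup: int = 5, trials: int = 20, ops_per_graph: int = 10):
    for _ in range(warmup):              # eager warmup before capture
        op()
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):        # record ops_per_graph launches into one graph
        for _ in range(ops_per_graph):
            op()
    for _ in range(warmup):              # graph warmup
        graph.replay()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(trials // ops_per_graph):
        graph.replay()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / trials * 1000  # avg ms per launch
```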
# Modified from https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100
"""
import argparse
import contextlib
import itertools
import logging
import os
import time
from typing import Optional
import torch # type: ignore
import torch.distributed as dist # type: ignore
from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce
from sglang.srt.distributed.parallel_state import (
cleanup_dist_env_and_memory,
graph_capture,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.layernorm import RMSNorm # noqa
from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as SGLANG_FP8_DTYPE
from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8
try:
from sgl_kernel import fused_add_rmsnorm as SGL_FUSED_ADD_RMS_NORM
from sgl_kernel import rmsnorm as SGL_RMS_NORM
from sgl_kernel import scaled_fp4_quant as SGL_SCALED_FP4_QUANT
except Exception: # pragma: no cover - fallback on non-supported platforms
SGL_FUSED_ADD_RMS_NORM = None
SGL_RMS_NORM = None
SGL_SCALED_FP4_QUANT = None
FP8_DTYPE = SGLANG_FP8_DTYPE
logger = logging.getLogger(__name__)
# Try to import FlashInfer
try:
import flashinfer.comm as flashinfer_comm # type: ignore
if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
flashinfer_comm = None
logger.warning(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except ImportError:
flashinfer_comm = None
logger.warning("FlashInfer not found, only benchmarking standard operations")
# Constants
MiB = 1024 * 1024
# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: 64 * MiB, # 64MB
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR = None
def setup_flashinfer_workspace(
world_size: int,
rank: int,
hidden_dim: int,
max_token_num: int,
use_fp32_lamport: bool = False,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE_TENSOR
if flashinfer_comm is None:
return None, None
if world_size not in _FI_MAX_SIZES:
logger.warning("FlashInfer not supported for world size %s", world_size)
return None, None
try:
# Create IPC workspace
ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=world_size,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
group=get_tp_group().device_group,
use_fp32_lamport=use_fp32_lamport,
)
)
_FI_WORKSPACE_TENSOR = workspace_tensor
return ipc_handles, workspace_tensor
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
return None, None
def cleanup_flashinfer_workspace(ipc_handles):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or ipc_handles is None:
return
try:
group = get_tp_group().device_group
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
rank: int,
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
):
self.rank = rank
self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
return {
"world_rank": self.rank,
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
}
def flashinfer_fused_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
use_oneshot: bool,
norm_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
allreduce_out=None,
quant_out=None,
scale_out=None,
layout_code=None,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
use_oneshot: bool = True,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=None,
layout_code=None,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=output_scale,
layout_code=None,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def standard_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm operations."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm
if residual is not None:
# Fused add + RMS norm (in-place on allreduce_out)
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
rms.forward_native(allreduce_out, residual)
else:
# Just RMS norm
if SGL_RMS_NORM is not None:
_ = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
_ = rms.forward_native(allreduce_out)
def standard_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP8 quantization."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm + static FP8 quantization
if residual is not None:
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
quant_out, _ = static_quant_fp8(
allreduce_out, scale_factor, repeat_scale=False
)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
normed, _ = rms.forward_native(allreduce_out, residual)
quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)
return quant_out, residual
else:
if SGL_RMS_NORM is not None:
normed = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
normed = rms.forward_native(allreduce_out)
quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)
return quant_out
def standard_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP4 quantization."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm
if residual is not None:
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
quant_input = allreduce_out
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
quant_input, _ = rms.forward_native(allreduce_out, residual)
residual_out = residual
else:
if SGL_RMS_NORM is not None:
quant_input = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
quant_input = rms.forward_native(allreduce_out)
residual_out = allreduce_out
# Finally FP4 quantization
if SGL_SCALED_FP4_QUANT is None:
raise RuntimeError("scaled_fp4_quant is not available on this platform")
quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)
if residual is not None:
return quant_res, residual_out, output_scale_res
else:
return quant_res, quant_input
def standard_allreduce_rmsnorm_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm operations using native RMSNorm forward."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
result = rmsnorm_layer.forward_native(allreduce_out, residual)
return result # Returns (norm_out, residual_out)
else:
result = rmsnorm_layer.forward_native(allreduce_out)
return result # Returns norm_out
def standard_allreduce_rmsnorm_fp8_quant_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP8 quantization using native implementations."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)
else:
norm_out = rmsnorm_layer.forward_native(allreduce_out)
residual_out = allreduce_out
# Apply native FP8 quantization
quant_out, _ = static_quant_fp8(norm_out, scale_factor, repeat_scale=False)
if residual is not None:
return quant_out, residual_out
else:
return quant_out
def standard_allreduce_rmsnorm_fp4_quant_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)
quant_input = norm_out
else:
norm_out = rmsnorm_layer.forward_native(allreduce_out)
quant_input = norm_out
residual_out = allreduce_out
# Apply FP4 quantization (still using fused CUDA op as there's no native FP4)
if SGL_SCALED_FP4_QUANT is None:
raise RuntimeError("scaled_fp4_quant is not available on this platform")
quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)
if residual is not None:
return quant_res, residual_out, output_scale_res
else:
return quant_res, norm_out
# Compiled versions of native functions
@torch.compile
def standard_allreduce_rmsnorm_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
norm_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm."""
return standard_allreduce_rmsnorm_native(
input_tensor, residual, rmsnorm_layer, norm_out
)
@torch.compile
def standard_allreduce_rmsnorm_fp8_quant_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm + FP8 quantization."""
return standard_allreduce_rmsnorm_fp8_quant_native(
input_tensor,
residual,
rmsnorm_layer,
scale_factor,
norm_out,
quant_out,
)
@torch.compile
def standard_allreduce_rmsnorm_fp4_quant_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm + FP4 quantization."""
return standard_allreduce_rmsnorm_fp4_quant_native(
input_tensor,
residual,
rmsnorm_layer,
input_global_scale,
quant_out,
output_scale,
norm_out,
)
def create_test_tensors(
seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
"""Create test tensors for benchmarking."""
input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype)
residual = (
torch.randn_like(input_tensor)
if use_residual
else torch.zeros_like(input_tensor)
)
rms_gamma = torch.ones(hidden_dim, dtype=dtype)
norm_out = None if use_residual else torch.empty_like(input_tensor)
# Quantization scales
scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8)
fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)
return (
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
)
def benchmark_operation(
operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
"""Benchmark a single operation using CUDA graphs."""
# Warmup before graph capture
for _ in range(warmup):
operation_func(*args, **kwargs)
torch.cuda.synchronize()
# Create CUDA graph
graph = torch.cuda.CUDAGraph()
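    # Batch several launches into one CUDA graph so each replay amortizes launch
    # overhead; the timing below divides by the total number of captured launches.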
num_op_per_cudagraph = 10
# Use sglang's graph_capture to make tensor_model_parallel_all_reduce graph-safe
with graph_capture() as graph_capture_context:
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(num_op_per_cudagraph):
operation_func(*args, **kwargs)
# Graph warmup
torch.cuda.synchronize()
for _ in range(warmup):
graph.replay()
# Benchmark with CUDA graph
torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(trials // num_op_per_cudagraph):
# operation_func(*args, **kwargs)
graph.replay()
torch.cuda.synchronize()
end_time = time.perf_counter()
avg_time_ms = ((end_time - start_time) / trials) * 1000
return avg_time_ms
def run_benchmarks(
seq_len: int,
hidden_dim: int,
dtype: torch.dtype,
use_residual: bool,
allreduce_params: Optional[FlashInferFusedAllReduceParams],
quant_mode: str = "all",
disable_oneshot: bool = False,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
"""
(
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual)
rms_eps = 1e-6
results = {}
# Create RMSNorm once for native benchmarks
rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps)
rmsnorm_layer.weight.data = rms_gamma
if quant_mode in ["all", "none"]:
# Standard AllReduce + RMSNorm
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
)
results["standard_allreduce_rmsnorm"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm failed: %s", e)
results["standard_allreduce_rmsnorm"] = float("inf")
# Standard AllReduce + RMSNorm Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
norm_out=norm_out,
)
results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e)
results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Two-shot
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e
)
results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf")
if quant_mode in ["all", "fp8_only"]:
# Standard AllReduce + RMSNorm + FP8 Quant
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
# quant_fp8_layer removed in sglang version; static_quant_fp8 is used within the function
scale_factor=scale_fp8,
norm_out=norm_out,
quant_out=quant_out_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float(
"inf"
)
if quant_mode in ["all", "fp4_only"]:
# Standard AllReduce + RMSNorm + FP4 Quant
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
)
results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
norm_out=norm_out,
)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
return results
def prepare_results_with_speedups(results_dict):
"""Prepare results with speedup calculations based on dynamic baseline selection."""
prepared_results = []
# Determine the fastest baseline for each operation type
def get_fastest_baseline(op_name, results_dict):
"""Get the fastest baseline between standard and native_compiled versions."""
if "fp8_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp8_quant",
"standard_allreduce_rmsnorm_fp8_quant_native_compiled",
]
elif "fp4_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp4_quant",
"standard_allreduce_rmsnorm_fp4_quant_native_compiled",
]
else:
candidates = [
"standard_allreduce_rmsnorm",
"standard_allreduce_rmsnorm_native_compiled",
]
# Find the fastest among available candidates
fastest_time = float("inf")
fastest_baseline = None
for candidate in candidates:
if (
candidate in results_dict
and results_dict[candidate] != float("inf")
and results_dict[candidate] < fastest_time
):
fastest_time = results_dict[candidate]
fastest_baseline = candidate
return fastest_baseline
# Create dynamic baseline mapping
dynamic_baseline_mapping = {}
for op_name in results_dict:
        if op_name.startswith("flashinfer_") or (
            op_name.startswith("standard_")
            and not op_name.endswith("_native_compiled")
        ):
dynamic_baseline_mapping[op_name] = get_fastest_baseline(
op_name, results_dict
)
for op_name, time_ms in results_dict.items():
if time_ms == float("inf"):
speedup_str = "FAILED"
time_str = "FAILED"
else:
time_str = f"{time_ms:.3f}"
# Find the appropriate baseline for this operation
baseline_op = dynamic_baseline_mapping.get(op_name)
if baseline_op and baseline_op in results_dict:
baseline_time = results_dict[baseline_op]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
# For baseline operations, determine if this is the fastest baseline
if op_name.endswith("_native_compiled") or (
op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
fastest_baseline = get_fastest_baseline(op_name, results_dict)
if fastest_baseline == op_name:
speedup_str = "baseline"
else:
if fastest_baseline and fastest_baseline in results_dict:
baseline_time = results_dict[fastest_baseline]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
prepared_results.append(
{
"operation": op_name,
"time_ms": time_ms,
"time_str": time_str,
"speedup_str": speedup_str,
}
)
return prepared_results
def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode):
"""Print benchmark results in a formatted table."""
print(f"\n{'=' * 80}")
print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}")
print(
f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, "
f"quant_mode={quant_mode}"
)
print(f"{'=' * 80}")
print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
print(f"{'-' * 80}")
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
for result in prepared_results:
if result["time_ms"] == float("inf"):
time_display = result["time_str"]
else:
time_display = f"{result['time_ms']:.3f}"
print(
f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}"
)
def format_results_markdown(
all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
"""Format all benchmark results as markdown."""
markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results
**World Size:** {world_size}
**Hidden Dimension:** {args.hidden_dim}
**Warmup Iterations:** {args.warmup}
**Benchmark Trials:** {args.trials}
**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"}
---
"""
for result in all_results:
seq_len = result["seq_len"]
dtype = result["dtype"]
use_residual = result["use_residual"]
results_dict = result["results"]
residual_str = "with residual" if use_residual else "no residual"
markdown += f"""
## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str}
| Operation | Time (ms) | Speedup |
|-----------|-----------|---------|
"""
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
        for prepared in prepared_results:
            # Format operation name for better readability
            formatted_op_name = prepared["operation"].replace("_", " ").title()
            markdown += (
                f"| {formatted_op_name} | {prepared['time_str']} | "
                f"{prepared['speedup_str']} |\n"
            )
markdown += "\n"
return markdown
def save_results_to_file(
all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
"""Save benchmark results to markdown file (only on rank 0)."""
if rank != 0:
return
if not all_results:
logger.warning("No results to save")
return
output_path = args.output_file
try:
markdown_content = format_results_markdown(all_results, world_size, args)
with open(output_path, "w") as f:
f.write(markdown_content)
except Exception as e:
logger.error("Failed to save results to file: %s", e)
def main():
parser = argparse.ArgumentParser(
description="Benchmark fused collective operations"
)
parser.add_argument(
"--seq-lens",
type=int,
nargs="+",
default=[128, 512, 1024, 2048],
help="Sequence lengths to test",
)
parser.add_argument(
"--hidden-dim", type=int, default=8192, help="Hidden dimension size"
)
parser.add_argument(
"--dtypes",
type=str,
nargs="+",
default=["bfloat16"],
choices=["float16", "bfloat16", "float32"],
help="Data types to test",
)
parser.add_argument(
"--no-residual",
action="store_true",
help="Skip residual connection tests",
)
# Quantization mode options (mutually exclusive with --no-quant)
quant_group = parser.add_mutually_exclusive_group()
quant_group.add_argument(
"--no-quant", action="store_true", help="Skip all quantization tests"
)
quant_group.add_argument(
"--quant-fp8", action="store_true", help="Only run FP8 quantization tests"
)
quant_group.add_argument(
"--quant-fp4", action="store_true", help="Only run FP4 quantization tests"
)
quant_group.add_argument(
"--quant-all",
action="store_true",
help="Run all quantization tests (default)",
)
parser.add_argument(
"--disable-oneshot",
action="store_true",
help="Disable oneshot mode for FlashInfer operations",
)
parser.add_argument(
"--warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--trials", type=int, default=20, help="Number of benchmark trials"
)
parser.add_argument(
"--output-file",
type=str,
help="""Output file path for markdown results
(default: benchmark_results_<timestamp>.md)
""",
)
args = parser.parse_args()
# Check if running with torchrun (required for collective operations)
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
raise RuntimeError(
"Must run with torchrun for distributed benchmarking. "
"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
)
# Initialize distributed environment
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=rank,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Validate world size (must be > 1 for collective operations)
if world_size <= 1:
raise ValueError(
"World size must be > 1 for collective operations benchmarking. "
f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
)
# Determine quantization mode
if args.no_quant:
quant_mode = "none"
elif args.quant_fp8:
quant_mode = "fp8_only"
elif args.quant_fp4:
quant_mode = "fp4_only"
else: # args.quant_all or default
quant_mode = "all"
if rank == 0:
logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
logger.info("Quantization mode: %s", quant_mode)
if flashinfer_comm is not None:
oneshot_status = "enabled" if not args.disable_oneshot else "disabled"
logger.info(
"FlashInfer available - will benchmark fused operations (oneshot: %s)",
oneshot_status,
)
else:
logger.info(
"FlashInfer not available - only benchmarking standard operations"
)
# Convert dtype strings to torch dtypes
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtypes = [dtype_map[dt] for dt in args.dtypes]
# Test configurations
    residual_options = [False] if args.no_residual else [True, False]
configs = list(itertools.product(args.seq_lens, dtypes, residual_options))
# Setup FlashInfer workspace if available
ipc_handles = None
allreduce_params = None
if flashinfer_comm is not None:
        # Derive the max token count from the per-world-size workspace budget,
        # the hidden dimension, and the world size
max_num_token = _FI_MAX_SIZES.get(world_size) // (
args.hidden_dim * world_size * 2
)
ipc_handles, workspace_tensor = setup_flashinfer_workspace(
world_size, rank, args.hidden_dim, max_num_token
)
if workspace_tensor is not None:
allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=world_size,
max_token_num=max_num_token,
)
# Collect all results for markdown export
all_results = []
try:
# Run benchmarks
for seq_len, dtype, use_residual in configs:
if rank == 0:
logger.info(
"\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s",
seq_len,
args.hidden_dim,
dtype,
use_residual,
)
results = run_benchmarks(
seq_len,
args.hidden_dim,
dtype,
use_residual,
allreduce_params,
quant_mode=quant_mode,
disable_oneshot=args.disable_oneshot,
)
# Store results for markdown export
if rank == 0:
all_results.append(
{
"seq_len": seq_len,
"hidden_dim": args.hidden_dim,
"dtype": str(dtype).replace("torch.", ""),
"use_residual": use_residual,
"quant_mode": quant_mode,
"results": results,
}
)
print_results(
results,
seq_len,
args.hidden_dim,
dtype,
use_residual,
quant_mode,
)
# Save results to markdown file
if args.output_file and rank == 0:
save_results_to_file(all_results, world_size, args, rank)
finally:
# Cleanup
if ipc_handles is not None:
cleanup_flashinfer_workspace(ipc_handles)
with contextlib.suppress(Exception):
dist.barrier()
cleanup_dist_env_and_memory(shutdown_ray=False)
if __name__ == "__main__":
main()
## Tuning Triton MoE Kernels
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
### Tuning Tool
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
Example usage:
```bash
# Tune Mixtral-8x7B with default settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tune
# Tune Qwen2-57B with FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--tp-size 4 \
--dtype fp8_w8a8 \
--tune
# Tune Qwen3-235B-A22B-FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen3-235B-A22B-FP8 \
--tp-size 4 \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-V3 with FP8 and TP=8
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8 \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-R1 with channel-wise INT8 and TP=16
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model meituan/DeepSeek-R1-Channel-INT8 \
--tp-size 16 \
--dtype int8_w8a8 \
--tune
```
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to the `sglang/srt/layers/fused_moe_triton/configs/triton_version` directory so that `sglang` can pick it up.
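As an illustration of that naming convention, a hypothetical helper (not part of sglang; `E` is the number of experts, `N` the shard intermediate size) could compose such a filename like this:
```python
# Hypothetical helper illustrating the tuned-config filename convention above.
import torch

def tuned_config_filename(E: int, N: int, dtype: str) -> str:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    return f"E={E},N={N},device_name={device_name},dtype={dtype}.json"

# e.g. "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json"
print(tuned_config_filename(64, 640, "fp8_w8a8"))
```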
### Performance Comparison Tool
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
Example usage:
```bash
# Compare with default settings (Mixtral model)
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--use-fp8-w8a8
# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8
```
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
- `benchmark_torch_compile_fused_moe.py`: A tool for comparing the performance of the fused MoE kernel compiled with `torch.compile` against the original fused MoE kernel.
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`; note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused MoE kernels.
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
import argparse
import torch
import triton
from transformers import AutoConfig
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
def fused_moe_triton_api(
x,
w1,
w2,
input_gating,
topk,
):
topk_op = TopK(
top_k=topk,
renormalize=False,
use_grouped_topk=False,
)
topk_op.use_triton_kernels = True
triton_topk_output = topk_op.forward_cuda(
hidden_states=x,
router_logits=input_gating,
)
moe_runner_config = MoeRunnerConfig(
inplace=False,
)
return triton_kernel_moe_forward(
x,
w1,
w2,
triton_topk_output,
moe_runner_config,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
topk_output = select_experts(
hidden_states=x,
router_logits=input_gating,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
return fused_moe_sglang(
x,
w1,
w2,
topk_output,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
line_arg="provider",
line_vals=[
"sglang_fused_moe_triton_v340",
"sglang_fused_moe_triton",
],
line_names=[
"sglang_fused_moe_triton_v340",
"sglang_fused_moe_triton",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(
batch_size,
provider,
model_config,
use_fp8_w8a8=False,
use_cuda_graph: bool = False,
):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
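    # The triton_kernel_moe_forward path uses contiguous weight copies with the last
    # two dims swapped; the standard fused_moe path uses the original layout.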
w1_tri = w1.clone()
w2_tri = w2.clone()
w1_tri = w1_tri.transpose(-2, -1).contiguous()
w2_tri = w2_tri.transpose(-2, -1).contiguous()
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
if provider == "sglang_fused_moe_triton_v340":
api_func = fused_moe_triton_api
api_kwargs = {
"x": x,
"w1": w1_tri,
"w2": w2_tri,
"input_gating": input_gating,
"topk": topk,
}
else:
api_func = fused_moe_sglang_api
api_kwargs = {
"x": x,
"w1": w1,
"w2": w2,
"input_gating": input_gating,
"topk": topk,
"use_fp8_w8a8": use_fp8_w8a8,
"block_shape": block_shape,
}
# Warmup
for _ in range(10):
_ = api_func(**api_kwargs)
torch.cuda.synchronize()
if use_cuda_graph:
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
api_func(**api_kwargs)
torch.cuda.synchronize()
bench_lambda = lambda: graph.replay()
else:
bench_lambda = lambda: api_func(**api_kwargs)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/sglang_fused_moe/",
)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
try:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:23456",
world_size=1,
rank=0,
)
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="nccl" if torch.cuda.is_available() else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
use_cuda_graph=args.use_cuda_graph,
)
finally:
destroy_model_parallel()
destroy_distributed_environment()
if __name__ == "__main__":
main()
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
mask_token = offs_token < token_num
mask_dim = offs_dim < hidden_dim
base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tile = tl.load(
base_ptrs + i * input_stride_1,
mask=mask_token[:, None] & mask_dim[None, :],
other=0.0,
)
accumulator += tile.to(tl.float32)
accumulator *= routed_scaling_factor
# -------- Write back --------
store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
tl.store(
store_ptrs,
accumulator.to(input_ptr.dtype.element_ty),
mask=mask_token[:, None] & mask_dim[None, :],
)
# _moe_sum_reduce_kernel is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
def moe_sum_reduce(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
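    # Launch one program per (token block, hidden-dim block); the kernel loops over
    # the top-k dimension and accumulates in fp32 before scaling and storing.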
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 1
num_warps = 16
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def compute_sum_scaled_baseline(
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
return out
@torch.compile
def compute_sum_scaled_compiled(
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
torch.sum(x * routed_scaling_factor, dim=1, out=out)
return out
def get_benchmark():
num_tokens_range = [2**i for i in range(0, 13)]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=num_tokens_range,
line_arg="version",
line_vals=["baseline", "compiled", "triton"],
line_names=["Original", "TorchCompile", "TritonKernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="sum_scaled_performance",
args={},
)
)
def benchmark(num_tokens, version):
topk = 9
hidden_size = 4096
dtype = torch.bfloat16
scaling_factor = 0.3
x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
# Warmup
for _ in range(3):
if version == "baseline":
compute_sum_scaled_baseline(x, out, scaling_factor)
elif version == "compiled":
compute_sum_scaled_compiled(x, out, scaling_factor)
else:
moe_sum_reduce(x, out, scaling_factor)
# Benchmark
quantiles = [0.5, 0.2, 0.8]
if version == "baseline":
ms, min_ms, max_ms = do_bench(
lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
quantiles=quantiles,
)
elif version == "compiled":
ms, min_ms, max_ms = do_bench(
lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = do_bench(
lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def verify_correctness(num_tokens=1024):
x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
scaling_factor = 0.3
out_baseline = torch.empty_like(x[:, 0])
compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
out_compiled = torch.empty_like(out_baseline)
compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
out_triton = torch.empty_like(out_baseline)
moe_sum_reduce(x, out_triton, scaling_factor)
if torch.allclose(
out_baseline, out_compiled, atol=1e-2, rtol=1e-2
) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
print(
f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
)
print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
if __name__ == "__main__":
print("Running correctness verification...")
verify_correctness()
print("\nRunning performance benchmark...")
benchmark = get_benchmark()
benchmark.run(
print_data=True,
# save_path="./configs/benchmark_ops/sum_scaled/"
)
# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
import triton
from torch.nn import functional as F
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_triton,
)
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
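# Worked example (illustrative numbers): for Mixtral-8x7B with intermediate_size=14336 and
# tp_size=2, the fused gate/up projection gives shard_intermediate_size = 2 * 14336 // 2 = 14336,
# i.e. w1 carries both halves and w2 uses shard_intermediate_size // 2 as its inner dimension.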
def fused_topk_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
@torch.compile(dynamic=False)
def fused_moe_torch(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
) -> torch.Tensor:
assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
topk_weights, topk_ids = fused_topk_native(
hidden_states=x,
gating_output=input_gating,
topk=topk,
renormalize=True,
)
w13_weights = w1[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = w2[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
def fused_moe_torch_compile(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
return fused_moe_torch(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
return fused_moe_triton(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list(range(1, 5)),
line_arg="provider",
line_vals=[
"fused_moe_triton",
"fused_moe_torch_compile",
],
line_names=[
"fused_moe_triton",
"fused_moe_torch_compile",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
set_torch_compile_config()
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
w1_scale = w2_scale = a1_scale = a2_scale = None
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
# Warmup
api_func = (
fused_moe_torch_compile
if provider == "fused_moe_torch_compile"
else fused_moe_sglang_api
)
for _ in range(10):
y = api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)[0],
quantiles=quantiles,
)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/fused_moe_torch_compile/",
)
args = parser.parse_args()
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
)
if __name__ == "__main__":
main()
# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
import triton
import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
vllm_version_num = (
vllm.__version_tuple__[0] * 100
+ vllm.__version_tuple__[1] * 10
+ vllm.__version_tuple__[2]
)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
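# Version-encoding note (illustrative): vllm_version_num packs the version tuple as
# major * 100 + minor * 10 + patch, e.g. vLLM 0.6.6 -> 0 * 100 + 6 * 10 + 6 = 66, which is what
# the block-wise fp8 assertion above compares against. The DeepseekV3 branch adds one expert slot,
# matching the shared-experts-fusion handling in the tuning script further below
# (--disable-shared-experts-fusion).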
def fused_moe_vllm_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
if block_shape is not None:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
else:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
return fused_moe_sglang(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list(range(1, 513)),
line_arg="provider",
line_vals=[
"vllm_fused_moe_triton",
"sglang_fused_moe_triton",
],
line_names=[
"vllm_fused_moe_triton",
"sglang_fused_moe_triton",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1_scale = w2_scale = a1_scale = a2_scale = None
if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
if block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
# Warmup
api_func = (
fused_moe_vllm_api
if provider == "vllm_fused_moe_triton"
else fused_moe_sglang_api
)
for _ in range(10):
y = api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)[0],
quantiles=quantiles,
)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
)
args = parser.parse_args()
try:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:23456",
world_size=1,
rank=0,
)
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="nccl" if torch.cuda.is_available() else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
)
finally:
destroy_model_parallel()
destroy_distributed_environment()
if __name__ == "__main__":
main()
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
get_config_file_name,
get_default_config,
get_moe_configs,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int] = None,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16 or use_int8_w8a8:
w1 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8 or use_int8_w8a8:
if use_int8_w8a8 and block_shape is None:
w1_scale = torch.randn(
num_experts, shard_intermediate_size, dtype=torch.float32
)
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
elif block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
)
topk_output = select_experts(x, input_gating, topk_config)
def prepare(i: int):
input_gating = gating_output[i]
new_topk_output = select_experts(x, input_gating, topk_config)
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
topk_output.router_logits.copy_(new_topk_output.router_logits)
def run():
moe_runner_config = MoeRunnerConfig(
inplace=True,
)
with override_config(config):
fused_moe(
x,
w1,
w2,
topk_output,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
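# Timing note (descriptive): each CUDA graph replay launches 10 fused_moe invocations, so the
# average divides the summed elapsed times by num_iters * 10 and multiplies by 1000 to convert
# milliseconds to microseconds, e.g. 100 iterations totalling 50 ms -> 50 / 1000 * 1000 = 50 us
# per invocation.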
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
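# Search-space size (illustrative): on CUDA the nested loops above enumerate
# 4 (num_stages) * 5 (block_m) * 3 (block_k) * 4 (block_n) * 2 (num_warps) * 4 (group_size)
# = 1920 candidate configs per batch size, before any block_shape filtering.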
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
)
if op_config is None:
config = get_default_config(
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype_str,
False,
block_shape,
)
else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
return config, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
num_iters=10,
)
except (triton.runtime.autotuner.OutOfResources, RuntimeError):
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
if args.tune:
search_space = get_configs_compute_bound()
if block_shape is not None:
block_n, block_k = block_shape[0], block_shape[1]
search_space = [
config
for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
start = time.perf_counter()
configs = _distribute(
"tune",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
search_space,
)
for batch_size in batch_sizes
],
)
best_configs = {
M: sort_config(config) for M, config in zip(batch_sizes, configs)
}
save_configs(
best_configs,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
for batch_size in batch_sizes
],
)
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument(
"--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
default="auto",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
args = parser.parse_args()
main(args)
import itertools
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
# The original MiniMax-Text-01 modeling file reads this flag from the environment; it is assumed
# False here so that forward() dispatches to inference() during the benchmark.
do_eval = False
@triton.jit
def _decode_kernel(
Q,
K,
V,
KV,
Out,
S,
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
d_original: tl.constexpr,
e: tl.constexpr,
e_original: tl.constexpr,
):
off_bh = tl.program_id(0)
off_h = off_bh % h
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
kv_offset = off_bh * d * e
s = tl.load(S + off_h)
ratio = tl.exp(-s)
d_idx = tl.arange(0, d)
e_idx = tl.arange(0, e)
# Create masks for original dimensions
d_mask = d_idx < d_original
e_mask = e_idx < e_original
# Load with masking
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
# Load KV with 2D masking
kv = tl.load(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
mask=(d_mask[:, None] & e_mask[None, :]),
other=0.0,
)
# Compute outer product using element-wise operations
k_v_prod = k[:, None] * v[None, :]
kv = ratio * kv + k_v_prod
# Store KV with 2D masking
tl.store(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
kv.to(KV.dtype.element_ty),
mask=(d_mask[:, None] & e_mask[None, :]),
)
# Compute matrix-vector multiplication using element-wise operations and reduction
o = tl.sum(q[:, None] * kv, axis=0)
# Store output with masking
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
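# Recurrence implemented by _decode_kernel (descriptive): per (batch, head) pair it performs one
# linear-attention decode step, kv_new = exp(-s) * kv + k^T v and o = q @ kv_new, with the padded
# dims d/e masked back to d_original/e_original on every load and store.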
def lightning_attn_decode(q, k, v, kv, s):
"""Triton implementation of Lightning Attention decode operation"""
b, h, n, d = q.shape
e = v.shape[-1]
assert n == 1, "Sequence length must be 1 in decode mode"
# Get padded dimensions (power of 2)
d_padded = next_power_of_2(d)
e_padded = next_power_of_2(e)
# Create output tensor (padded)
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
# Create padded tensors without actually padding the data
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
kv_padded = torch.empty(
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
)
# Copy data to padded tensors
q_padded[..., :d] = q
k_padded[..., :d] = k
v_padded[..., :e] = v
kv_padded[..., :d, :e] = kv
# Launch kernel
grid = (b * h, 1)
_decode_kernel[grid](
q_padded,
k_padded,
v_padded,
kv_padded,
o_padded,
s,
b=b,
h=h,
n=n,
d=d_padded,
d_original=d,
e=e_padded,
e_original=e,
)
# Get unpadded outputs
o = o_padded[..., :e]
kv_out = kv_padded[..., :d, :e]
return o, kv_out
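# Padding note (illustrative): tl.arange requires power-of-two extents, so the head and value dims
# are padded up front, e.g. head_dim=96 -> d_padded=128; the extra lanes are masked inside the
# kernel and only the leading d/e slices of the outputs are returned.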
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
class MiniMaxText01LightningAttention(nn.Module):
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
super().__init__()
if config is None:
config = type("Config", (), kwargs)
bias = False
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.out_proj = nn.Linear(
self.head_dim * self.num_heads, self.hidden_size, bias=bias
)
self.act = get_activation_fn(config.hidden_act)
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
self.qkv_proj = nn.Linear(
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
)
self.output_gate = nn.Linear(
self.hidden_size, self.head_dim * self.num_heads, bias=bias
)
# for inference only
self.offset = 0
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
**kwargs,
):
if (not self.training) and (not do_eval):
return self.inference(
hidden_states,
attn_mask,
output_attentions,
past_key_value,
use_cache,
slope_rate,
)
def inference(
self,
x,
attn_mask: Optional[torch.Tensor] = None, # (b, n)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
):
# x: b n d
b, n, d = x.shape
# linear map
qkv = self.act(self.qkv_proj(x))
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
q = q.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
k = k.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
v = v.transpose(1, 2) # [b, n, h, d] -> [b, h, n, e]
self.offset += 1
ratio = torch.exp(-slope_rate) # [h, 1, 1]
# decode mode
kv = past_key_value # [b, h, d, e]
output = []
for i in range(n):
# kv: [b, h, d, e]
# ratio: [h, 1, 1]
# k: [b, h, n, d]
# v: [b, h, n, e]
# k[:, :, i : i + 1]: [b, h, 1, d]
# v[:, :, i : i + 1]: [b, h, 1, e]
# ratio * kv: [b, h, d, e]
# torch.einsum(
# "... n d, ... n e -> ... d e",
# k[:, :, i : i + 1],
# v[:, :, i : i + 1],
# )
# [b, h, d, e] + [b, h, d, e] -> [b, h, d, e]
kv = ratio * kv + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
# q[:, :, i : i + 1]: [b, h, 1, d]
# kv.to(q.dtype): [b, h, d, e]
# torch.einsum(
# "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
# )
# [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e]
qkv = torch.einsum(
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
output = self.norm(output)
# gate
output = F.sigmoid(self.output_gate(x)) * output
# outproj
output = self.out_proj(output)
attn_weights = None
return output, attn_weights, kv
def get_activation_fn(activation):
if activation == "gelu":
return F.gelu
elif activation == "relu":
return F.relu
elif activation == "elu":
return F.elu
elif activation == "sigmoid":
return F.sigmoid
elif activation == "exp":
def f(x):
with torch.no_grad():
x_max = torch.max(x, dim=-1, keepdims=True).values
y = torch.exp(x - x_max)
return y
return f
elif activation == "leak":
return F.leaky_relu
elif activation == "1+elu":
def f(x):
return 1 + F.elu(x)
return f
elif activation == "2+elu":
def f(x):
return 2 + F.elu(x)
return f
elif activation == "silu" or activation == "swish":
return F.silu
elif activation == "sine":
return torch.sin
else:
return lambda x: x
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def test_lightning_attention_implementations(model_params):
torch.manual_seed(42)
batch_size = 64
seq_len = 1
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
model_attn.eval()
d = model_params["head_dim"]
past_kv = torch.randn(
batch_size,
model_params["num_attention_heads"],
d,
d,
device=device,
)
with torch.no_grad():
model_output, _, new_kv = model_attn.inference(
hidden_states,
attn_mask=attention_mask,
slope_rate=slope_rate,
past_key_value=past_kv,
)
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
past_kv = past_kv.contiguous()
slope_rate = slope_rate.contiguous()
# Test Triton implementation
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
triton_output = triton_output.transpose(1, 2).contiguous()
triton_output = triton_output.view(batch_size, seq_len, -1)
triton_output = model_attn.norm(triton_output)
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
triton_output = model_attn.out_proj(triton_output)
# Test SGL implementation
sgl_output = torch.empty_like(v)
sgl_new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
sgl_output = sgl_output.transpose(1, 2).contiguous()
sgl_output = sgl_output.view(batch_size, seq_len, -1)
sgl_output = model_attn.norm(sgl_output)
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
sgl_output = model_attn.out_proj(sgl_output)
# Verify Triton implementation results
torch.testing.assert_close(
model_output,
triton_output,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
triton_new_kv,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different kv results",
)
# Verify SGL implementation results
torch.testing.assert_close(
model_output,
sgl_output,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
sgl_new_kv,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different kv results",
)
print("✅ All implementations match")
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
n_attention_heads, 1, 1
)
return slopes
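# Slope example (illustrative): for 8 heads, start = 2 ** -(2 ** -(log2(8) - 3)) = 0.5 and the
# geometric slopes are 0.5, 0.25, ..., 0.5 ** 8, following the ALiBi construction; non-power-of-two
# head counts interleave the slopes of the two nearest powers of two.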
def get_benchmark():
batch_size_range = [i for i in range(1, 33)] # max 32
seq_length_range = [1] # decode mode sequence length is fixed to 1
configs = list(itertools.product(batch_size_range, seq_length_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["Original", "Triton", "SGL"],
line_names=[
"Original PyTorch Implementation",
"Triton Implementation",
"SGL Implementation",
],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "gelu",
}
hidden_states = torch.randn(
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
model_attn.eval()
d = params["head_dim"]
past_kv = torch.randn(
batch_size,
params["num_attention_heads"],
d,
d,
device=device,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "Original":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: model_attn.inference(
hidden_states,
attn_mask=attention_mask,
slope_rate=slope_rate,
past_key_value=past_kv,
),
quantiles=quantiles,
)
elif provider == "Triton":
def run_triton():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_triton,
quantiles=quantiles,
)
else: # SGL
def run_sgl():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(
q, k, v, past_kv, slope_rate, output, new_kv
)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_sgl,
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_decode/",
help="Path to save lightning attention decode benchmark results",
)
args = parser.parse_args()
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "silu",
}
# Run correctness test first
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
test_lightning_attention_implementations(params)
# Run performance benchmark
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)
import itertools
import logging
import math
import os
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
@triton.jit
def _fwd_kernel(
Q,
K,
V,
Out,
S, # log lambda
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK: tl.constexpr,
BLOCK_MODEL: tl.constexpr,
):
##### get offset
off_bh = tl.program_id(0)
off_h = off_bh % h
off_e = tl.program_id(1)
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
# channel offset
e_offset = off_e * BLOCK_MODEL
##### get block ptr
Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
S_block_ptr = S + off_h
##### init diag decay(Lambda); q, k decay; kv
s = tl.load(S_block_ptr)
# q, k decay
off_block = tl.arange(
0, BLOCK
) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
# diag decay
index = off_block[:, None] - off_block[None, :]
s_index = s * index
s_index = tl.where(index >= 0, -s_index, float("-inf"))
diag_decay = tl.exp(s_index)
kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)
##### compute
for i in range(NUM_BLOCK):
# load
q = tl.load(
Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0
).to(tl.float32)
k_trans = tl.load(
K_trans_block_ptr + off_block[None, :] * d,
mask=off_block[None, :] < n,
other=0.0,
).to(tl.float32)
v = tl.load(
V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0
).to(tl.float32)
# compute
qk = tl.dot(q, k_trans) * diag_decay
o_intra = tl.dot(qk, v)
o_inter = tl.dot(q, kv) * q_decay
o = o_intra + o_inter
# save and update
tl.store(
O_block_ptr + off_block[:, None] * e,
o.to(O_block_ptr.dtype.element_ty),
mask=off_block[:, None] < n,
)
kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
off_block += BLOCK
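# Block recurrence implemented by _fwd_kernel (descriptive): within each BLOCK of the sequence the
# causally masked, decayed product q k^T * diag_decay is applied directly (o_intra), while
# contributions from earlier blocks flow through the running kv state scaled by q_decay (o_inter);
# kv itself is updated as block_decay * kv + (k * k_trans_decay)^T v after every block.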
def lightning_attn2(q, k, v, s):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
s = s.contiguous()
b, h, n, d = q.shape
e = v.shape[-1]
# Pad d to next power of 2
d_padded = next_power_of_2(d)
if d_padded != d:
q_padded = F.pad(q, (0, d_padded - d))
k_padded = F.pad(k, (0, d_padded - d))
else:
q_padded = q
k_padded = k
# Pad e to next power of 2
e_padded = next_power_of_2(e)
if e_padded != e:
v_padded = F.pad(v, (0, e_padded - e))
else:
v_padded = v
o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device)
BLOCK = 64
NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
# parallel over channel
BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32)
grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL))
_fwd_kernel[grid](
q_padded,
k_padded,
v_padded,
o_padded,
s,
b,
h,
n,
d_padded,
e_padded,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
BLOCK_MODEL=BLOCK_MODEL,
)
# Remove padding from output
if e_padded != e:
o = o_padded[..., :e]
else:
o = o_padded
return o
def is_support(dim):
return 16 % dim
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
def lightning_attn_func(q, k, v, s):
b, h, n, d = q.shape
e = v.shape[-1]
assert is_support(d) and is_support(e)
# pad v's feature dim to power of 2
e_pad = next_power_of_2(e)
need_pad = e_pad != e
if need_pad:
v = F.pad(v, (0, e_pad - e))
if d > 128:
# split over head
if 64 % d:
m = 64
elif 32 % d:
m = 32
elif 16 % d:
m = 16
arr = [m * i for i in range(d // m + 1)]
if arr[-1] != d:
arr.append(d)
n = len(arr)
o = 0
for i in range(n - 1):
start = arr[i]
end = arr[i + 1]
q1 = q[..., start:end]
k1 = k[..., start:end]
o += lightning_attn2(q1, k1, v, s)
else:
o = lightning_attn2(q, k, v, s)
if need_pad:
o = o[:, :, :, :e]
return o
logger = logging.getLogger(__name__)
debug = eval(os.environ.get("debug", default="False"))
do_eval = eval(os.environ.get("do_eval", default="False"))
BLOCK = 256
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
def get_activation_fn(activation):
if debug:
logger.info(f"activation: {activation}")
if activation == "gelu":
return F.gelu
elif activation == "relu":
return F.relu
elif activation == "elu":
return F.elu
elif activation == "sigmoid":
return F.sigmoid
elif activation == "exp":
def f(x):
with torch.no_grad():
x_max = torch.max(x, dim=-1, keepdims=True).values
y = torch.exp(x - x_max)
return y
return f
elif activation == "leak":
return F.leaky_relu
elif activation == "1+elu":
def f(x):
return 1 + F.elu(x)
return f
elif activation == "2+elu":
def f(x):
return 2 + F.elu(x)
return f
elif activation == "silu" or activation == "swish":
return F.silu
elif activation == "sine":
return torch.sin
else:
logger.info(f"activation: does not support {activation}, use Identity!!!")
return lambda x: x
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
class MiniMaxText01LightningAttention(nn.Module):
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
super().__init__()
if config is None:
config = type("Config", (), kwargs)
bias = False
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.out_proj = nn.Linear(
self.head_dim * self.num_heads, self.hidden_size, bias=bias
)
self.act = get_activation_fn(config.hidden_act)
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
self.qkv_proj = nn.Linear(
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
)
self.output_gate = nn.Linear(
self.hidden_size, self.head_dim * self.num_heads, bias=bias
)
# for inference only
self.offset = 0
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
**kwargs,
):
if (not self.training) and (not do_eval):
return self.inference(
hidden_states,
attn_mask,
output_attentions,
past_key_value,
use_cache,
slope_rate,
)
def inference(
self,
x,
attn_mask: Optional[torch.Tensor] = None, # (b, n)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
):
# x: b n d
b, n, d = x.shape
# linear map
qkv = self.act(self.qkv_proj(x))
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if past_key_value is None:
self.offset = q.shape[-2]
else:
self.offset += 1
# for align with metaseq
ratio = torch.exp(-slope_rate)
# only use for the first time
if past_key_value is None:
slope_rate = slope_rate.to(torch.float32)
if attn_mask is not None:
v = v.masked_fill(
(1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
)
NUM_BLOCK = (n + BLOCK - 1) // BLOCK
b, h, n, d = q.shape
e = v.shape[-1]
# other
array = torch.arange(BLOCK).to(q) + 1
q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
index = array[:, None] - array[None, :]
s_index = (
slope_rate
* index[
None,
None,
]
)
s_index = torch.where(index >= 0, -s_index, float("-inf"))
diag_decay = torch.exp(s_index)
kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
for i in range(NUM_BLOCK):
si = i * BLOCK
ei = min(si + BLOCK, n)
m = ei - si
qi = q[:, :, si:ei].contiguous()
ki = k[:, :, si:ei].contiguous()
vi = v[:, :, si:ei].contiguous()
qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
# diag
qk = (
torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
* diag_decay[:, :, :m, :m]
)
qkv_diag = torch.matmul(qk, vi.to(torch.float32))
block_decay = torch.exp(-slope_rate * m)
output[:, :, si:ei] = qkv_none_diag + qkv_diag
kv = block_decay * kv + torch.matmul(
(ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
)
else:
kv = past_key_value
output = []
for i in range(n):
kv = ratio * kv + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
qkv = torch.einsum(
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
output = self.norm(output)
# gate
output = F.sigmoid(self.output_gate(x)) * output
# outproj
output = self.out_proj(output)
attn_weights = None
return output, attn_weights, kv
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
# In the paper, we only train models that have 2^a heads for some a. This
# function has some good properties that only occur when the input is a power
# of 2. To maintain that even when the number of heads is not a power of 2,
# we use this workaround.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
# h, 1, 1
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
n_attention_heads, 1, 1
)
return slopes
def test_lightning_attention_implementations(model_params):
torch.manual_seed(42)
batch_size = 2
seq_len = 1024
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
model_attn.eval()
with torch.no_grad():
model_output, _, _ = model_attn.inference(
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
)
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
lib_output = lightning_attn_func(q, k, v, slope_rate)
lib_output = lib_output.transpose(1, 2).contiguous()
lib_output = lib_output.view(batch_size, seq_len, -1)
lib_output = model_attn.norm(lib_output)
lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
lib_output = model_attn.out_proj(lib_output)
torch.testing.assert_close(
model_output,
lib_output,
rtol=1e-3,
atol=1e-2,
msg="Lightning attention implementations produce different results",
)
print("✅ Two implementations match")
def get_benchmark():
batch_size_range = [2**i for i in range(0, 7)] # max 64
seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096
configs = list(itertools.product(batch_size_range, seq_length_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["MiniMax-Text-01", "OpenNLPLab"],
line_names=[
"MiniMax-Text-01 Model Implementation",
"OpenNLPLab Library Implementation",
],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="lightning-attention-prefill-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "gelu",
}
hidden_states = torch.randn(
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
model_attn.eval()
quantiles = [0.5, 0.2, 0.8]
if provider == "MiniMax-Text-01":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: model_attn.inference(
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
),
quantiles=quantiles,
)
else:
def run_lib():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
lib_output = lightning_attn_func(q, k, v, slope_rate)
lib_output = lib_output.transpose(1, 2).contiguous()
lib_output = lib_output.view(batch_size, seq_len, -1)
lib_output = model_attn.norm(lib_output)
lib_output = (
torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
)
return model_attn.out_proj(lib_output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_lib,
quantiles=quantiles,
)
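# do_bench reports milliseconds; convert to microseconds to match the "us" ylabel.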
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_prefill/",
help="Path to save lightning attention prefill benchmark results",
)
args = parser.parse_args()
# Run correctness test first
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "silu",
}
test_lightning_attention_implementations(params)
# Run performance benchmark
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)
import argparse
import itertools
import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper
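# These benchmarks check and time three paths for masked grouped activation
# quantization: the Triton FP8 post-quant kernel, the unfused CUDA path
# (silu_and_mul followed by scaled_fp4_grouped_quant), and the fused
# silu_and_mul_scaled_fp4_grouped_quant kernel.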
def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
out1, blk_scales1 = scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
)
torch.testing.assert_close(out, out1)
torch.testing.assert_close(blk_scales, blk_scales1)
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "K"],
x_vals=list(itertools.product(Ms, Ks)),
x_log=False,
line_arg="provider",
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
ylabel="ms",
plot_name="fp4 quant",
args={},
)
)
def benchmark(M, K, provider):
E = 6
device = "cuda"
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
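# silu_and_mul halves the last dimension (the input packs the gate and up
# projections), so the FP8 output buffer below is K // 2 wide and its scales
# are grouped along that halved dimension.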
fp8_out = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2,
),
device=x.device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
fp8_scales = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2 // scale_block_size,
),
device=x.device,
dtype=torch.float32,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "triton_fp8":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_masked_post_quant_fwd(
x,
fp8_out,
fp8_scales,
scale_block_size,
masks,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
),
quantiles=quantiles,
)
if provider == "cuda_unfused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
),
quantiles=quantiles,
)
if provider == "cuda_fused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_scaled_fp4_grouped_quant(
x,
glb_scales,
masks,
),
quantiles=quantiles,
)
return ms, min_ms, max_ms
def test_accuracy():
E = 6
N_RANKS = 48
Ms = [128, 256, 512, 1024]
Ks = [2048, 4096, 7168]
input_dtype = torch.bfloat16
for M in Ms:
for K in Ks:
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_fp4_quant_res",
help="Path to save fp4 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
import argparse
import torch
import triton
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
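# Reference per-token symmetric int8 quantization: for each row,
# scale = max(|x|) / 127 and q = round(x / scale), so q stays within
# [-127, 127].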
@torch.compile(backend="inductor")
def torch_int8_quant(x):
int8_max = torch.iinfo(torch.int8).max
abs_max = x.abs().max(dim=-1, keepdim=True).values
scales = abs_max.to(torch.float32) / float(int8_max)
q_x = (x / scales).round().to(torch.int8)
return q_x, scales
def _test_accuracy_once(M, K, input_dtype, device):
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
out1, scales1 = per_token_quant_int8(x)
out2, scales2 = torch_int8_quant(x)
torch.testing.assert_close(out, out2, atol=1, rtol=0)
torch.testing.assert_close(out, out1, atol=1, rtol=0)
torch.testing.assert_close(scales, scales2)
torch.testing.assert_close(scales1, scales2)
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
def test_accuracy():
Ms = [1, 13, 128, 1024, 2048, 4096]
Ks = [512, 1024, 2048, 8192]
input_dtypes = [torch.float16, torch.bfloat16]
for M in Ms:
for K in Ks:
for input_dtype in input_dtypes:
_test_accuracy_once(M, K, input_dtype, "cuda")
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["vllm op", "triton", "torch.compile"],
line_names=["vllm op", "triton", "torch.compile"],
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
ylabel="ms",
plot_name="int8 per token quant",
args={},
)
)
def benchmark(batch_size, provider):
M, K = batch_size, 16384
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm op":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_int8_quant(x, symmetric=True),
quantiles=quantiles,
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: per_token_quant_int8(x),
quantiles=quantiles,
)
if provider == "torch.compile":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_int8_quant(x),
quantiles=quantiles,
)
return ms, min_ms, max_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_int8_quant_res",
help="Path to save int8 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import json
import multiprocessing as mp
import os
import time
from datetime import datetime
from typing import Any, Dict, List
import torch
import triton
from tqdm import tqdm
mp.set_start_method("spawn", force=True)
from sglang.srt.layers.quantization.fp8_kernel import (
_w8a8_block_fp8_matmul,
_w8a8_block_fp8_matmul_unrolledx4,
)
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
_is_hip = is_hip()
DTYPE_MAP = {
"float32": torch.float32,
"float16": torch.float16,
"half": torch.half,
"bfloat16": torch.bfloat16,
}
def w8a8_block_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
config: Dict[str, Any],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use the manually unrolled (x4) kernel on AMD GPUs when the grid size is small.
# Empirical testing shows the sweet spot is when the grid size is less than the
# number of compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
else:
kernel = _w8a8_block_int8_matmul
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
def get_rocm_configs_compute_bound():
configs = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound():
configs = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
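# The enumerated search space has 4 * 5 * 2 * 4 * 2 * 4 = 1280 CUDA configs or
# 1 * 4 * 4 * 5 * 2 * 5 = 800 ROCm configs; combinations that fail to compile
# are skipped later during tuning.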
def get_weight_shapes(tp_size):
# NOTE(HandH1998): These weight shapes only work for DeepSeek-V3. Modify them if you tune a different model.
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
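# Example with tp_size == 8: the N-shardable shape (18432 * 2, 7168) becomes
# (4608, 7168) and the K-shardable shape (7168, 18432) becomes (7168, 2304),
# while the "cannot TP" shapes are kept as-is.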
def benchmark_config(
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
# JIT compilation & warmup
for _ in range(5):
run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
start_event.record()
run()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
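# elapsed_time() returns milliseconds; the scaled average below is only used to
# rank candidate configs in tune(), so its absolute units are not critical.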
avg = sum(latencies) / (num_iters * 10) * 1000 # us
return avg
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
factor_for_scale = 1e-2
if input_type == "fp8":
fp8_info = torch.finfo(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
else:
int8_info = torch.iinfo(torch.int8)
int8_max, int8_min = int8_info.max, int8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
Bs = (
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
* factor_for_scale
)
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
A,
B,
As,
Bs,
block_size,
config,
out_dtype,
num_iters=10,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
assert best_config is not None
return best_config
def save_configs(
N,
K,
block_n,
block_k,
configs,
save_path,
input_type="fp8",
) -> None:
os.makedirs(save_path, exist_ok=True)
device_name = get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
with open(config_file_path, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_available_gpu_count():
"""Get the number of available GPUs."""
return torch.cuda.device_count()
def tune_on_gpu(args_dict):
"""Run tuning on a specific GPU."""
gpu_id = args_dict["gpu_id"]
batch_sizes = args_dict["batch_sizes"]
weight_shapes = args_dict["weight_shapes"]
args = args_dict["args"]
torch.cuda.set_device(gpu_id)
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
block_n = args.block_n
block_k = args.block_k
out_dtype = DTYPE_MAP[args.out_dtype]
save_path = args.save_path
input_type = args.input_type
search_space = get_configs_compute_bound()
search_space = [
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
start = time.perf_counter()
results = {}
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
N, K = shape[0], shape[1]
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
benchmark_results = [
tune(
batch_size,
N,
K,
[block_n, block_k],
out_dtype,
search_space,
input_type,
)
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.perf_counter()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
def distribute_batch_sizes(batch_sizes, num_gpus):
"""Distribute batch sizes across available GPUs."""
batches_per_gpu = []
for i in range(num_gpus):
start_idx = i * len(batch_sizes) // num_gpus
end_idx = (i + 1) * len(batch_sizes) // num_gpus
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
return batches_per_gpu
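# Example: 18 batch sizes over 4 GPUs yields slices [0:4], [4:9], [9:13], [13:18],
# so each GPU gets 4 or 5 batch sizes.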
def main(args):
print(args)
num_gpus = get_available_gpu_count()
if num_gpus == 0:
raise RuntimeError("No GPU available for tuning")
print(f"Found {num_gpus} GPUs for parallel tuning")
torch.cuda.init()
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
num_gpus = 1 # If only one batch size, use only one GPU
weight_shapes = get_weight_shapes(args.tp_size)
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
process_args = []
for gpu_id in range(num_gpus):
process_args.append(
{
"gpu_id": gpu_id,
"batch_sizes": batches_per_gpu[gpu_id],
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args,
}
)
ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool:
pool.map(tune_on_gpu, process_args)
print("Multi-GPU tuning completed")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument(
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
)
parser.add_argument(
"--out-dtype",
type=str,
choices=["float32", "float16", "bfloat16", "half"],
default="float16",
)
parser.add_argument("--block-n", type=int, default=128)
parser.add_argument("--block-k", type=int, default=128)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument(
"--save-path", type=str, default="python/sglang/srt/layers/quantization/configs"
)
args = parser.parse_args()
main(args)