"torchvision/vscode:/vscode.git/clone" did not exist on "14c003bd58dbf02393292517bf724d62e93360d9"
Unverified Commit 3c06b673 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[8/N] MoE Refactor: deprecate `EPMoE` (#11211)

parent 7c3f07db
## Benchmark FBGEMM Grouped GEMM
Benchmark FBGEMM Grouped GEMM in both Triton and CUDA version and SGLang Triton Grouped GEMM, it will be used to compare the bandwidth of different implementations.
### Requirements
```shell
pip install fbgemm-gpu-genai
```
### Usage
```bash
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
```
For example, in H200, the Qwen2-57B-A14B-Instruct TP4 fp8w8a8 grouped gemm bandwidth result is as follows:
```shell
grouped-gemm-performance:
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
0 256.0 3704.841339 3042.626402 2254.725030
1 512.0 3691.426346 3029.065684 2269.504543
2 1024.0 3653.938629 2258.471467 2358.319020
3 2048.0 3596.644313 2271.611904 2476.895397
4 4096.0 3468.496435 2231.283986 2179.473910
```
The theoretical peak bandwidth of H200 is 4.8 TB/s. Taking batch_size 256 as an example, the bandwidth of FBGEMM Triton Grouped GEMM FP8 is 3704.841339 GB/s, the bandwidth of FBGEMM CUTLASS F8F8BF16 Rowwise is 3042.626402 GB/s, and the bandwidth of SGLang Grouped GEMM FP8 is 2254.725030 GB/s. Therefore, FBGEMM Triton Grouped GEMM FP8 achieves 77.9% of H200's theoretical peak bandwidth, FBGEMM CUTLASS F8F8BF16 Rowwise achieves 63.4% of H200's theoretical peak bandwidth, and SGLang Grouped GEMM FP8 achieves 46.9% of H200's theoretical peak bandwidth.
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse
import torch
import triton
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
quantize_fp8_row,
triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
def get_model_config(model_name: str, tp_size: int):
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
num_groups = config.ffn_config.moe_num_experts
intermediate_size = config.ffn_config.ffn_hidden_size
elif config.architectures[0] == "JambaForCausalLM":
num_groups = config.num_experts
intermediate_size = config.intermediate_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
]:
num_groups = config.n_routed_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
num_groups = config.text_config.num_local_experts
intermediate_size = config.text_config.intermediate_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
num_groups = config.num_local_experts
intermediate_size = config.moe_intermediate_size
else:
num_groups = config.num_local_experts
intermediate_size = config.intermediate_size
shape_configs = {
"num_groups": num_groups,
"hidden_size": config.hidden_size,
"intermediate_size": intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
base_weights = torch.randn(
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
w_sglang = base_weights
c_fbgemm = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
for i in range(1, num_groups + 1):
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
return (
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
)
def create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
"""
Create test data for FP8 grouped GEMM operations.
Args:
batch_size: Total batch size
num_groups: Number of groups
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
Returns:
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
For cutlass: (x, wq, w_scale, m_sizes)
"""
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
# Create weight matrices for each group
w_list = []
for _ in range(num_groups):
w = torch.randn(
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
)
w_list.append(w)
# Quantize weights using quantize_fp8_row for each group
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
if backend == "triton":
# Triton format: concatenated weights
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
# Create m_sizes as int32 for triton
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
# Create and quantize input
x_fp16 = torch.randn(
batch_size, hidden_size, dtype=torch.float16, device="cuda"
)
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
x_scale = x_scale.view(batch_size, -1)
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
elif backend == "cutlass":
# CUTLASS format: stacked weights
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Create m_sizes as int64 for cutlass
m_values = [tokens_per_group] * num_groups
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
# Create input data - separate for each group then concat
x_list = []
for _ in range(num_groups):
x = torch.randn(
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
)
x_list.append(x)
# Concatenate inputs into single tensor
x = torch.concat(x_list, dim=0).contiguous()
return x, wq, w_scale, m_sizes
else:
raise ValueError(f"Unsupported backend: {backend}")
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
"""
Calculate memory bandwidth based on accessed expert weights.
Args:
m_sizes: Tensor containing batch sizes for each group
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
dtype: Data type of weights
Returns:
Memory size in bytes for accessed expert weights
"""
# Count non-zero groups (active experts)
if hasattr(m_sizes, "cpu"):
active_experts = torch.count_nonzero(m_sizes).item()
else:
active_experts = sum(1 for m in m_sizes if m > 0)
# Calculate bytes per element based on dtype
if dtype in [torch.float16, torch.bfloat16]:
bytes_per_element = 2
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
bytes_per_element = 1
elif dtype == torch.float32:
bytes_per_element = 4
else:
# Default to 2 bytes for unknown dtypes
bytes_per_element = 2
# Memory per expert weight matrix
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
# Total memory for active experts
total_memory_bytes = active_experts * memory_per_expert
return total_memory_bytes
def get_benchmark_config(use_fp8_w8a8=False):
if use_fp8_w8a8:
return {
"line_vals": [
"fbgemm_triton_grouped_gemm_fp8",
"fbgemm_cutlass_f8f8bf16_rowwise",
"sglang_grouped_gemm",
],
"line_names": [
"FBGEMM Triton Grouped GEMM FP8",
"FBGEMM CUTLASS F8F8BF16 Rowwise",
"SGLang Grouped GEMM FP8",
],
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
}
else:
return {
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
"line_names": [
"FBGEMM Triton Grouped GEMM BF16",
"SGLang Grouped GEMM BF16",
],
"styles": [("blue", "-"), ("green", "-")],
}
def run_benchmark(
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
config = get_benchmark_config(use_fp8_w8a8)
benchmark_config = triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[256, 512, 1024, 2048, 4096],
line_arg="provider",
line_vals=config["line_vals"],
line_names=config["line_names"],
styles=config["styles"],
ylabel="Bandwidth (GB/s)",
plot_name="grouped-gemm-performance",
args={},
)
@triton.testing.perf_report(benchmark_config)
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"Benchmarking {provider} with batch_size={batch_size}")
torch.cuda.manual_seed_all(0)
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
if provider == "fbgemm_triton_grouped_gemm_fp8":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="triton",
)
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
def run_func():
return fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
except Exception as e:
print(f"FP8 not supported, skipping: {e}")
return float("inf"), float("inf"), float("inf")
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="cutlass",
)
x, wq, w_scale, m_sizes = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
# Quantize input using triton_quantize_fp8_row
xq, x_scale = triton_quantize_fp8_row(x)
x_scale = x_scale.view(batch_size, -1)
def run_func():
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
)
except Exception as e:
print(
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
f"skipping: {e}"
)
return float("inf"), float("inf"), float("inf")
else:
test_data = create_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
# Calculate memory bandwidth for BF16 operations
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.bfloat16
)
if provider == "fbgemm_triton_grouped_gemm":
def run_func():
return fbgemm_grouped_gemm(
x, w_fbgemm, m_sizes, use_fast_accum=True
)
else:
def run_func():
return sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
for _ in range(10):
try:
run_func()
except Exception as e:
print(f"Error during warmup for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
torch.cuda.synchronize()
try:
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
# Convert time (ms) to bandwidth (GB/s)
# Bandwidth = Memory (bytes) / Time (seconds)
# Convert ms to seconds and bytes to GB (1e9)
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
# min bandwidth = max time, max bandwidth = min time
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
return gb_per_s, min_gb_per_s, max_gb_per_s
except Exception as e:
print(f"Error during benchmarking for {provider}: {e}")
return 0.0, 0.0, 0.0
dynamic_benchmark.run(
show_plots=True,
print_data=True,
save_path=save_path,
model_config=model_config,
use_fp8_w8a8=use_fp8_w8a8,
)
def verify_correctness(model_config):
print("Verifying correctness...")
batch_size = 128
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
print("✓ BF16 Correctness verification passed!")
else:
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
return False
return True
def main():
parser = argparse.ArgumentParser(
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
help="Model name to get configuration from",
)
parser.add_argument(
"--tp-size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument(
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
)
parser.add_argument(
"--save-path",
type=str,
default="./benchmark_grouped_gemm/",
help="Path to save benchmark results",
)
parser.add_argument(
"--verify-correctness",
action="store_true",
help="Verify correctness before benchmarking",
)
args = parser.parse_args()
try:
model_config = get_model_config(args.model, args.tp_size)
except Exception as e:
print(f"Failed to get model config: {e}")
print("Using default configuration...")
model_config = {
"num_groups": 8,
"hidden_size": 4096,
"intermediate_size": 14336,
"dtype": torch.bfloat16,
}
print("Running benchmark with:")
print(f" num_groups: {model_config['num_groups']}")
print(f" hidden_size: {model_config['hidden_size']}")
print(f" intermediate_size: {model_config['intermediate_size']}")
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
if args.verify_correctness:
if not verify_correctness(model_config):
print("Correctness verification failed. Exiting...")
return
try:
run_benchmark(
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
save_path=args.save_path,
)
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()
......@@ -246,7 +246,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
| `--moe-runner-backend` | Select the runner backend for MoE. | auto |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
......
......@@ -13,22 +13,18 @@ from sgl_kernel import (
from sglang.srt.layers.moe.ep_moe.kernels import (
post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
run_moe_ep_preproess,
)
def cutlass_w4a8_moe(
start_expert_id: int,
end_expert_id: int,
total_num_experts: int,
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
local_topk_ids: torch.Tensor,
topk_ids: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
......@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
......@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
Returns:
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
......@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0)
num_local_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1)
topk = topk_ids.size(1)
if apply_router_weight_on_input:
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
device = a.device
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
_, src2dst, _ = run_cutlass_moe_ep_preproess(
local_topk_ids,
num_experts,
_, src2dst, _ = run_moe_ep_preproess(
topk_ids,
num_local_experts,
)
gateup_input = torch.empty(
......@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
a,
gateup_input,
src2dst,
local_topk_ids,
topk_ids,
a1_scale,
total_num_experts,
num_local_experts,
topk,
k,
BLOCK_SIZE=512,
......@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
# they are kept to allow for a quick switch of the permutation logic
# from the current triton kernel implementation to the cutlass-based one if needed.
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data(
local_topk_ids,
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
num_local_experts,
n,
k,
)
......@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
c2,
output,
src2dst,
local_topk_ids,
topk_ids,
topk_weights,
num_experts,
topk,
num_local_experts,
k,
0,
BLOCK_SIZE=512,
)
return output
......@@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
expert = tl.program_id(0)
expert_id_minus_1 = tl.program_id(0) - 1
low = 0
high = num_toks - 1
target_location = -1
while low <= high:
mid = (low + high) // 2
if tl.load(reorder_topk_ids + mid) > expert:
if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
high = mid - 1
else:
low = mid + 1
target_location = mid
tl.store(seg_indptr + expert + 1, target_location + 1)
tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
seg_indptr = torch.zeros(
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
compute_seg_indptr_triton_kernel[(num_local_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
......@@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
return reorder_topk_ids, src2dst, seg_indptr
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(
local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(
local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
)
BLOCK_SIZE = 512
grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
)
return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
input_ptr,
......@@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
num_experts,
num_local_experts,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id != num_experts:
if a1_scales_ptr is not None:
scale = 1.0 / tl.load(a1_scales_ptr)
else:
scale = 1.0
dst_idx = tl.load(src2dst_ptr + idx)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
out_data = (in_data * scale).to(OutDtype)
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def pre_reorder_triton_kernel(
input_ptr,
gateup_input_ptr,
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
use_per_token_if_dynamic: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
......@@ -244,15 +188,11 @@ def pre_reorder_triton_kernel(
vec = tl.arange(0, BLOCK_SIZE)
if a1_scales_ptr is not None and use_per_token_if_dynamic:
scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
if expert_id != num_local_experts:
if a1_scales_ptr is not None:
if not use_per_token_if_dynamic:
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
scale = 1.0 / tl.load(a1_scales_ptr)
else:
scale = 1.0
......@@ -267,52 +207,6 @@ def pre_reorder_triton_kernel(
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def silu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# silu & mul & quantize
gate_output = gate_output * tl.sigmoid(gate_output)
gate_output = gate_output.to(InDtype)
silu_mul_output = gate_output * up_output * scale
silu_mul_output = silu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
......@@ -461,84 +355,15 @@ def silu_and_mul_masked_post_quant_fwd(
@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# gelu & mul & quantize
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
# sqrt(2/pi)
kAlpha = 0.7978845608028654
gate_output = (
0.5
* gate_output
* (
1
+ tanh(
kAlpha
* (
gate_output
+ 0.044715 * gate_output * gate_output * gate_output
)
)
)
)
gate_output = gate_output.to(InDtype)
gelu_mul_output = gate_output * up_output * scale
gelu_mul_output = gelu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
@triton.jit
def post_reorder_triton_kernel(
def post_reorder_triton_kernel_for_cutlass_moe(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
num_local_experts,
hidden_size,
dst_start,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
......@@ -549,7 +374,6 @@ def post_reorder_triton_kernel(
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
computed = False
store_ptr = output_ptr + src_idx * hidden_size
vec = tl.arange(0, BLOCK_SIZE)
......@@ -561,37 +385,25 @@ def post_reorder_triton_kernel(
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
if expert_id != num_local_experts:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - dst_start
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)
if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
@triton.jit
def post_reorder_triton_kernel_for_cutlass_moe(
def post_reorder_triton_kernel(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
num_experts,
topk,
hidden_size,
dst_start,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
......@@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id != num_experts:
if expert_id > 0:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - dst_start
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
......@@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
tl.store(store_ptr + offset, sum_vec, mask=mask)
@triton.jit
def compute_m_range(
pid,
batch_size,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
BLOCK_SIZE_M: tl.constexpr,
):
idx = 0
for bs in range(batch_size):
tiles = tl.load(m_num_tiles_indptr + bs)
if pid >= tiles:
idx = bs
idx_start = tl.load(m_num_tiles_indptr + idx)
m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
expert_id = tl.load(weight_indices + idx)
return m_range_start, m_range_end, expert_id
@triton.jit
def grouped_gemm_triton_kernel(
a,
b,
c,
batch_size,
N,
K,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
scale_a,
scale_b,
use_fp8_w8a8: tl.constexpr,
group_n: tl.constexpr,
group_k: tl.constexpr,
a_stride_0: tl.constexpr,
b_stride_0: tl.constexpr,
b_stride_1: tl.constexpr,
as_stride_0: tl.constexpr,
as_stride_1: tl.constexpr,
bs_stride_0: tl.constexpr,
bs_stride_2: tl.constexpr,
bs_stride_1: tl.constexpr,
use_per_token_if_dynamic: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
c_dtype = c.dtype.element_ty
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
total_m_block = tl.load(m_num_tiles_indptr + batch_size)
if pid_m >= total_m_block:
return
m_range_start, m_range_end, expert_id = compute_m_range(
pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
)
if m_range_end - m_range_start == 0:
return
n_range_start = pid_n * BLOCK_SIZE_N
n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
offs_am = tl.arange(0, BLOCK_SIZE_M)
offs_bn = tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
b_ptr = b + (
(expert_id * b_stride_0)
+ (n_range_start + offs_bn[:, None]) * b_stride_1
+ offs_k[None, :]
)
if group_k > 0 and group_n > 0:
a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
offs_bsn = (n_range_start + offs_bn) // group_n
b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a_tile = tl.load(
a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
b_tile = tl.load(
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
else:
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
a_ptr += BLOCK_SIZE_K
b_ptr += BLOCK_SIZE_K
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
if use_per_token_if_dynamic:
scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
else:
scale_a_value = tl.load(scale_a + expert_id)
scale_b_value = tl.load(scale_b + expert_id)
accumulator *= scale_a_value * scale_b_value
c_tile = accumulator.to(c_dtype)
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
tl.store(c_ptr, c_tile, mask=c_mask)
@triton.jit
def compute_m_num_tiles_indptr(
m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
for bs in range(batch_size):
m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
def grouped_gemm_triton(
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
batch_size: int,
weight_column_major: bool,
seg_indptr: Optional[torch.Tensor] = None,
weight_indices: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
scale_a: torch.Tensor = None,
scale_b: torch.Tensor = None,
block_shape: Optional[List[int]] = None,
c_dtype=None,
use_per_token_if_dynamic: bool = True,
):
assert weight_column_major == True # TODO: more
if use_fp8_w8a8 and block_shape is None:
assert scale_a is not None and scale_b is not None
if block_shape is not None:
a_original = a
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
a, scale_a = per_token_group_quant_fp8(a, block_k)
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
dispose_tensor(a_original)
# TODO: adjust config or tune kernel
# Reduce block size to prevent L40 shared memory overflow.
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
}
m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
compute_m_num_tiles_indptr[(1,)](
m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
)
if c is None:
assert c_dtype is not None
c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
grid = lambda META: (
triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
)
if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
assert (
scale_a.shape[0] == a.shape[0]
), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
grouped_gemm_triton_kernel[grid](
a,
b,
c,
batch_size,
b.size(1),
b.size(2),
seg_indptr,
weight_indices,
m_num_tiles_indptr,
scale_a,
scale_b,
use_fp8_w8a8,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
a.stride(0),
b.stride(0),
b.stride(1),
scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
use_per_token_if_dynamic,
**config,
)
return c
@triton.jit
def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert,
......@@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel(
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
expert_dst_offset = dst_id - expert_dst_start
dst_id = expert_id * m_max + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)
......@@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel(
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
......@@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel(
vec = tl.arange(0, BLOCK_SIZE)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
if expert_id >= 0:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
......@@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel(
def moe_ep_deepgemm_preprocess(
topk_ids: torch.Tensor,
num_experts: int,
num_local_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
seg_indptr = torch.zeros(
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
m_max = (hidden_states.size(0) + 255) // 256 * 256
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
m_max = (hidden_states.size(0) // 256 + 1) * 256
expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
(num_local_experts, m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)
......@@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess(
block_shape = [128, 128]
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
# TODO: fuse this with the preprocess
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
gateup_input_scale = torch.empty(
......@@ -1345,18 +928,14 @@ def moe_ep_deepgemm_preprocess(
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)
return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
masked_m,
expected_m,
src2dst,
gateup_input,
......
from __future__ import annotations
import logging
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
......@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
......@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
CUTEDSL_MOE_NVFP4_DISPATCH,
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import (
ceil_div,
dispose_tensor,
get_bool_env_var,
get_int_env_var,
is_cuda,
is_hip,
is_npu,
)
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
......@@ -72,29 +56,13 @@ if _use_aiter:
logger = logging.getLogger(__name__)
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
class EPMoE(FusedMoE):
class DeepEPMoE(FusedMoE):
"""
MoE Expert Parallel Impl
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
"""
_has_printed = False
def __init__(
self,
num_experts: int,
......@@ -108,272 +76,29 @@ class EPMoE(FusedMoE):
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
top_k=top_k,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
)
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, topk_output)
else:
return super().forward(hidden_states, topk_output)
def forward_deepgemm(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)
# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
dispose_tensor(hidden_states)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = gateup_input_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
(
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
gateup_input_scale
)
),
)
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
del gateup_input
del gateup_input_fp8
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
del down_input
del down_input_fp8
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
"""
_has_printed = False
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file
......@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
N = self.w13_weight.size(1)
scale_block_size = 128
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
w13_weight_fp8 = (
self.w13_weight,
(
......@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
return FlashInferFusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
return EPMoE
return FusedMoE
......
......@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1
......
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
@dataclass
class DeepGemmRunnerInput(RunnerInput):
hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor
masked_m: torch.Tensor
expected_m: int
use_masked_gemm: bool
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmRunnerOutput(RunnerOutput):
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmMoeQuantInfo(MoeQuantInfo):
w13_weight: torch.Tensor
w2_weight: torch.Tensor
use_fp8: bool
w13_scale: Optional[torch.Tensor] = None
w2_scale: Optional[torch.Tensor] = None
block_shape: Optional[List[int]] = None
class DeepGemmRunnerCore(MoeRunnerCore):
def __init__(self, config: MoeRunnerConfig):
super().__init__(config)
assert self.config.activation == "silu"
def run(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> DeepGemmRunnerOutput:
if runner_input.use_masked_gemm:
hidden_states = self._run_masked_gemm(
runner_input,
quant_info,
running_state,
)
else:
hidden_states = self._run_contiguous_gemm(
runner_input,
quant_info,
running_state,
)
return DeepGemmRunnerOutput(hidden_states=hidden_states)
def _run_masked_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale
masked_m = runner_input.masked_m
expected_m = runner_input.expected_m
w13_weight = quant_info.w13_weight
w2_weight = quant_info.w2_weight
w13_scale = quant_info.w13_scale
w2_scale = quant_info.w2_scale
hidden_states_device = running_state["hidden_states_device"]
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = hidden_states_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
else:
hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
hidden_states_scale
)
num_groups, m, k = hidden_states.shape
n = w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(hidden_states, hidden_states_scale),
(w13_weight, w13_scale),
gateup_output,
masked_m,
expected_m,
)
dispose_tensor(hidden_states)
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = w2_weight.shape[1]
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
down_input_scale
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(down_input, down_input_scale),
(w2_weight, w2_scale),
down_output,
masked_m,
expected_m,
)
del down_input
return down_output
def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
pass
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
dispatch_output: StandardDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
hidden_states, topk_output = dispatch_output
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
hidden_states_ref = hidden_states
topk_weights, topk_ids = topk_weights, topk_ids
# PreReorder
masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
runner_config.num_local_experts,
hidden_states,
runner_config.top_k,
quant_info.block_shape,
)
)
dispose_tensor(hidden_states_ref)
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
running_state["hidden_states_shape"] = hidden_states_shape
running_state["hidden_states_dtype"] = hidden_states_dtype
running_state["hidden_states_device"] = hidden_states_device
running_state["src2dst"] = src2dst
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
masked_m=masked_m,
expected_m=expected_m,
use_masked_gemm=True,
)
@register_post_permute("deep_gemm", "standard")
def post_permute_deep_gemm_to_standard(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
hidden_states_shape = running_state["hidden_states_shape"]
hidden_states_dtype = running_state["hidden_states_dtype"]
hidden_states_device = running_state["hidden_states_device"]
src2dst = running_state["src2dst"]
topk_ids = running_state["topk_ids"]
topk_weights = running_state["topk_weights"]
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
runner_output.hidden_states,
output,
src2dst,
topk_ids,
topk_weights,
runner_config.top_k,
hidden_states_shape[1],
BLOCK_SIZE=512,
)
dispose_tensor(runner_output.hidden_states)
if runner_config.routed_scaling_factor is not None:
output *= runner_config.routed_scaling_factor
return StandardCombineInput(
hidden_states=output,
)
......@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
MoeRunnerConfig,
PermuteMethodPool,
)
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
......@@ -30,6 +31,8 @@ class MoeRunner:
if runner_backend.is_triton():
self.runner_core = TritonRunnerCore(config)
elif runner_backend.is_deep_gemm():
self.runner_core = DeepGemmRunnerCore(config)
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
......
......@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
class MoeRunnerBackend(Enum):
AUTO = "auto"
DEEP_GEMM = "deep_gemm"
TRITON = "triton"
TRITON_KERNEL = "triton_kernel"
FLASHINFER_TRTLLM = "flashinfer_trtllm"
......@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
def is_auto(self):
return self == MoeRunnerBackend.AUTO
def is_deep_gemm(self):
return self == MoeRunnerBackend.DEEP_GEMM
def is_triton(self):
return self == MoeRunnerBackend.TRITON
......@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
logger.warning(
"MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
)
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
return MOE_RUNNER_BACKEND
......
......@@ -31,8 +31,8 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
......@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
from sglang.srt.layers.moe.utils import (
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
moe_runner_backend = get_moe_runner_backend()
if moe_runner_backend.is_auto():
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and get_moe_a2a_backend().is_deepep()
):
moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
else:
moe_runner_backend = MoeRunnerBackend.TRITON
if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
else:
# TODO(cwan): refactor other backends
pass
def apply(
self,
......@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
return StandardCombineInput(hidden_states=output)
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
if self.runner.runner_backend.is_deep_gemm():
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
if self.block_quant:
block_shape = self.quant_config.weight_block_size
w13_scale = layer.w13_weight_scale_inv
w2_scale = layer.w2_weight_scale_inv
else:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
block_shape = [scale_block_size, scale_block_size]
w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
w13_scale = (
layer.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_scale_k, dim=2)
)
w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
w2_scale = (
layer.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_scale_k, dim=2)
)
quant_info = DeepGemmMoeQuantInfo(
w13_weight=w13_weight,
w2_weight=w2_weight,
use_fp8=True,
w13_scale=w13_scale,
w2_scale=w2_scale,
block_shape=block_shape,
)
elif self.runner.runner_backend.is_triton():
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
else:
raise NotImplementedError(
"Unsupported runner backend: %s" % self.runner.runner_backend
)
return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits(
......
......@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
......@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
......@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def create_weights(
self,
layer: EPMoE,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
......@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: EPMoE,
layer: Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
......@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
if get_moe_expert_parallel_world_size() > 1:
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
output = cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
x,
layer.w13_weight,
layer.w2_weight,
......@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.a_strides1,
self.b_strides1,
self.c_strides1,
......
......@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
......@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
custom_routing_function=custom_routing_function,
)
kwargs = {}
if get_moe_expert_parallel_world_size() > 1:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE
kwargs["reduce_results"] = reduce_results
kwargs["use_presharded_weights"] = use_presharded_weights
kwargs["inplace"] = inplace
kwargs["no_combine"] = no_combine
self.experts = MoEImpl(
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
......@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
params_dtype=params_dtype,
quant_config=quant_config,
activation="gelu",
**kwargs,
reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
inplace=inplace,
no_combine=no_combine,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
......@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
renormalize=True,
)
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
self.experts = MoEImpl(
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
......
......@@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
MOE_RUNNER_BACKEND_CHOICES = [
"auto",
"deep_gemm",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
]
# Allow external code to add more choices
def add_load_format_choices(choices):
......@@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices):
GRAMMAR_BACKEND_CHOICES.extend(choices)
def add_moe_runner_backend_choices(choices):
MOE_RUNNER_BACKEND_CHOICES.extend(choices)
def add_deterministic_attention_backend_choices(choices):
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
......@@ -315,14 +330,7 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Literal["none", "deepep"] = "none"
moe_runner_backend: Literal[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
moe_runner_backend: str = "auto"
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
......@@ -2191,15 +2199,7 @@ class ServerArgs:
parser.add_argument(
"--moe-runner-backend",
type=str,
choices=[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
],
choices=MOE_RUNNER_BACKEND_CHOICES,
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
)
......
import itertools
import random
import unittest
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase
# For test
def ep_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
topk_config: TopKConfig,
# ep config
num_experts: int = 256,
fp8_dtype: torch.types = torch.float8_e4m3fn,
num_experts_per_partition: int = 128,
start_expert_id: int = 0,
end_expert_id: int = 127,
use_fp8_w8a8: bool = False,
w1_scale_inv: Optional[torch.Tensor] = None,
w2_scale_inv: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
use_blockwise_fp8 = block_shape is not None
top_k = topk_config.top_k
topk_output = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=topk_config,
)
topk_weights, topk_ids, _ = topk_output
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
gateup_input = torch.empty(
(int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
device=hidden_states.device,
dtype=(
fp8_dtype
if (use_fp8_w8a8 and not use_blockwise_fp8)
else hidden_states.dtype
),
)
if use_fp8_w8a8 and not use_blockwise_fp8:
max_value = (
torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
)
w1_input_scale = max_value / torch.finfo(fp8_dtype).max
else:
w1_input_scale = None
# PreReorder
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
src2dst,
topk_ids,
w1_input_scale,
start_expert_id,
end_expert_id,
top_k,
hidden_states.shape[1],
BLOCK_SIZE=512,
use_per_token_if_dynamic=True,
)
seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
num_experts_per_partition,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
gateup_input.shape[0],
w1.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gateup_output = grouped_gemm_triton(
a=gateup_input,
b=w1,
c=gateup_output,
batch_size=num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=use_fp8_w8a8,
scale_a=w1_input_scale,
scale_b=w1_scale_inv,
block_shape=block_shape,
)
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
fp8_dtype
if (use_fp8_w8a8 and not use_blockwise_fp8)
else hidden_states.dtype
),
)
if use_fp8_w8a8 and not use_blockwise_fp8:
w2_input_scale = torch.ones(
num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)
else:
w2_input_scale = None
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
w2_input_scale,
start_expert_id,
end_expert_id,
BLOCK_SIZE=512,
)
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
w2.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
down_output = grouped_gemm_triton(
a=down_input,
b=w2,
c=down_output,
batch_size=num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=use_fp8_w8a8,
scale_a=w2_input_scale,
scale_b=w2_scale_inv,
block_shape=block_shape,
)
# PostReorder
output = torch.empty_like(hidden_states)
post_reorder_triton_kernel[(hidden_states.size(0),)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
top_k,
hidden_states.size(1),
0,
BLOCK_SIZE=512,
)
return output
# test util
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise quantization.
The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
and the block size.
The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
Note only float8 is supported for now.
"""
# process 3D tensor
if x_q_block.dim() == 3:
batch_size = x_q_block.size(0)
return torch.stack(
[block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
)
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
x_dq_block_tiles = [
[
x_dq_block[
j * block_n : min((j + 1) * block_n, n),
i * block_k : min((i + 1) * block_k, k),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
return x_dq_block
class TestW8A8BlockFP8EPMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 222, 1024, 2048]
N = [128, 1024, 2048]
K = [256, 4096, 5120]
E = [8, 16]
ep_size = [2, 4]
TOP_KS = [2, 4]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_ep_moe(
self, M, N, K, E, ep_size, topk, block_size, dtype, seed
):
torch.manual_seed(seed)
random.seed(seed)
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
* factor_for_scale
)
w2_s = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
* factor_for_scale
)
w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)
score = torch.randn((M, E), dtype=dtype)
num_experts_per_partition = E // ep_size
cur_rank = random.randint(0, ep_size - 1)
start_id = cur_rank * num_experts_per_partition
end_id = start_id + num_experts_per_partition - 1
topk_config = TopKConfig(
top_k=topk,
renormalize=False,
)
with torch.inference_mode():
out = ep_moe(
hidden_states=a,
w1=w1,
w2=w2,
router_logits=score,
topk_config=topk_config,
use_fp8_w8a8=True,
w1_scale_inv=w1_s,
w2_scale_inv=w2_s,
block_shape=block_size,
num_experts=E,
num_experts_per_partition=num_experts_per_partition,
start_expert_id=start_id,
end_expert_id=end_id,
)
ref_out = ep_moe(
hidden_states=a,
w1=w1_ref,
w2=w2_ref,
router_logits=score,
topk_config=topk_config,
use_fp8_w8a8=False,
w1_scale_inv=None,
w2_scale_inv=None,
block_shape=None,
num_experts=E,
num_experts_per_partition=num_experts_per_partition,
start_expert_id=start_id,
end_expert_id=end_id,
)
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
< 0.06
)
def test_w8a8_block_fp8_ep_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.ep_size,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
ep_size=params[4],
topk=params[5],
block_size=params[6],
dtype=params[7],
seed=params[8],
):
self._w8a8_block_fp8_ep_moe(*params)
torch.cuda.empty_cache()
if __name__ == "__main__":
unittest.main(verbosity=2)
......@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
)
topk_weights, topk_ids, _ = topk_output
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E
expert_map[local_e:] = -1
output = cutlass_moe(
a,
......@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
c_strides2,
s_strides13,
s_strides2,
0,
local_e - 1,
E,
local_e,
a1_scale,
a2_scale,
expert_map,
......@@ -178,7 +176,7 @@ def cutlass_moe(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
topk_ids: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
......@@ -187,40 +185,32 @@ def cutlass_moe(
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
start_expert_id: int,
end_expert_id: int,
E: int,
num_local_experts: int,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
):
local_topk_ids = topk_ids_
local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
topk_ids = expert_map[topk_ids]
device = a.device
local_num_experts = end_expert_id - start_expert_id + 1
expert_offsets = torch.empty(
(local_num_experts + 1), dtype=torch.int32, device=device
(num_local_experts + 1), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
(num_local_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
(num_local_experts, 3), dtype=torch.int32, device=device
)
return cutlass_w4a8_moe(
start_expert_id,
end_expert_id,
E,
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids_,
local_topk_ids,
topk_ids,
a_strides1,
b_strides1,
c_strides1,
......
......@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
)
class TestEpMoE(CustomTestCase):
class TestEp(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
......@@ -34,18 +34,6 @@ class TestEpMoE(CustomTestCase):
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
......@@ -59,7 +47,7 @@ class TestEpMoE(CustomTestCase):
self.assertGreaterEqual(metrics["score"], 0.8)
class TestEpMoEFP8(CustomTestCase):
class TestEpDeepGEMM(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
......@@ -76,6 +64,8 @@ class TestEpMoEFP8(CustomTestCase):
"2",
"--quantization",
"fp8",
"--moe-runner-backend",
"deep_gemm",
],
)
......@@ -83,18 +73,6 @@ class TestEpMoEFP8(CustomTestCase):
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
......
......@@ -130,6 +130,7 @@ suites = {
TestFile("test_modelopt_loader.py", 30),
],
"per-commit-2-gpu": [
TestFile("ep/test_moe_ep.py", 140),
TestFile("lora/test_lora_tp.py", 116),
TestFile("rl/test_update_weights_from_distributed.py", 103),
TestFile("test_data_parallelism.py", 73),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment