Unverified commit 3c06b673 authored by Cheng Wan, committed by GitHub

[8/N] MoE Refactor: deprecate `EPMoE` (#11211)

parent 7c3f07db
## Benchmark FBGEMM Grouped GEMM
Benchmarks FBGEMM Grouped GEMM (both the Triton and CUDA/CUTLASS versions) against SGLang's Triton Grouped GEMM, so the achieved memory bandwidth of the different implementations can be compared.
### Requirements
```shell
pip install fbgemm-gpu-genai
```
### Usage
```bash
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
```
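Omitting `--use-fp8-w8a8` benchmarks the BF16 paths instead (FBGEMM Triton Grouped GEMM BF16 vs. SGLang Grouped GEMM BF16); pass `--verify-correctness` to run a BF16 numerical check before benchmarking.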
For example, on an H200, the grouped GEMM bandwidth results for Qwen2-57B-A14B-Instruct with TP4 and FP8 W8A8 are as follows:
```shell
grouped-gemm-performance:
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
0 256.0 3704.841339 3042.626402 2254.725030
1 512.0 3691.426346 3029.065684 2269.504543
2 1024.0 3653.938629 2258.471467 2358.319020
3 2048.0 3596.644313 2271.611904 2476.895397
4 4096.0 3468.496435 2231.283986 2179.473910
```
The theoretical peak bandwidth of the H200 is 4.8 TB/s. Taking batch_size 256 as an example, FBGEMM Triton Grouped GEMM FP8 reaches 3704.84 GB/s (77.2% of peak), FBGEMM CUTLASS F8F8BF16 Rowwise reaches 3042.63 GB/s (63.4% of peak), and SGLang Grouped GEMM FP8 reaches 2254.73 GB/s (47.0% of peak).
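As a quick sanity check, the utilization figures above can be recomputed directly from the table; a minimal sketch (4.8 TB/s is the nominal H200 HBM3e peak stated above, not a measured value):
```python
# Recompute the peak-bandwidth utilization quoted above for batch_size = 256.
H200_PEAK_GB_S = 4800.0  # nominal H200 HBM3e peak bandwidth (4.8 TB/s)

measured_gb_s = {
    "FBGEMM Triton Grouped GEMM FP8": 3704.841339,
    "FBGEMM CUTLASS F8F8BF16 Rowwise": 3042.626402,
    "SGLang Grouped GEMM FP8": 2254.725030,
}

for name, bandwidth in measured_gb_s.items():
    print(f"{name}: {bandwidth / H200_PEAK_GB_S:.1%} of peak")
# FBGEMM Triton Grouped GEMM FP8: 77.2% of peak
# FBGEMM CUTLASS F8F8BF16 Rowwise: 63.4% of peak
# SGLang Grouped GEMM FP8: 47.0% of peak
```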
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse
import torch
import triton
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
quantize_fp8_row,
triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
def get_model_config(model_name: str, tp_size: int):
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
num_groups = config.ffn_config.moe_num_experts
intermediate_size = config.ffn_config.ffn_hidden_size
elif config.architectures[0] == "JambaForCausalLM":
num_groups = config.num_experts
intermediate_size = config.intermediate_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
]:
num_groups = config.n_routed_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
num_groups = config.text_config.num_local_experts
intermediate_size = config.text_config.intermediate_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
num_groups = config.num_local_experts
intermediate_size = config.moe_intermediate_size
else:
num_groups = config.num_local_experts
intermediate_size = config.intermediate_size
shape_configs = {
"num_groups": num_groups,
"hidden_size": config.hidden_size,
"intermediate_size": intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
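# Assumes batch_size is evenly divisible by num_groups; tokens are split uniformly across experts.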
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
base_weights = torch.randn(
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
w_sglang = base_weights
c_fbgemm = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
for i in range(1, num_groups + 1):
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
return (
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
)
def create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
"""
Create test data for FP8 grouped GEMM operations.
Args:
batch_size: Total batch size
num_groups: Number of groups
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
Returns:
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
For cutlass: (x, wq, w_scale, m_sizes)
"""
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
# Create weight matrices for each group
w_list = []
for _ in range(num_groups):
w = torch.randn(
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
)
w_list.append(w)
# Quantize weights using quantize_fp8_row for each group
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
if backend == "triton":
# Triton format: concatenated weights
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
# Create m_sizes as int32 for triton
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
# Create and quantize input
x_fp16 = torch.randn(
batch_size, hidden_size, dtype=torch.float16, device="cuda"
)
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
x_scale = x_scale.view(batch_size, -1)
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
elif backend == "cutlass":
# CUTLASS format: stacked weights
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Create m_sizes as int64 for cutlass
m_values = [tokens_per_group] * num_groups
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
# Create input data - separate for each group then concat
x_list = []
for _ in range(num_groups):
x = torch.randn(
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
)
x_list.append(x)
# Concatenate inputs into single tensor
x = torch.concat(x_list, dim=0).contiguous()
return x, wq, w_scale, m_sizes
else:
raise ValueError(f"Unsupported backend: {backend}")
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
"""
Calculate the bytes of expert weights accessed; used to convert kernel time into an effective bandwidth.
Args:
m_sizes: Tensor containing batch sizes for each group
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
dtype: Data type of weights
Returns:
Memory size in bytes for accessed expert weights
"""
# Count non-zero groups (active experts)
if hasattr(m_sizes, "cpu"):
active_experts = torch.count_nonzero(m_sizes).item()
else:
active_experts = sum(1 for m in m_sizes if m > 0)
# Calculate bytes per element based on dtype
if dtype in [torch.float16, torch.bfloat16]:
bytes_per_element = 2
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
bytes_per_element = 1
elif dtype == torch.float32:
bytes_per_element = 4
else:
# Default to 2 bytes for unknown dtypes
bytes_per_element = 2
# Memory per expert weight matrix
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
# Total memory for active experts
total_memory_bytes = active_experts * memory_per_expert
return total_memory_bytes
def get_benchmark_config(use_fp8_w8a8=False):
if use_fp8_w8a8:
return {
"line_vals": [
"fbgemm_triton_grouped_gemm_fp8",
"fbgemm_cutlass_f8f8bf16_rowwise",
"sglang_grouped_gemm",
],
"line_names": [
"FBGEMM Triton Grouped GEMM FP8",
"FBGEMM CUTLASS F8F8BF16 Rowwise",
"SGLang Grouped GEMM FP8",
],
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
}
else:
return {
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
"line_names": [
"FBGEMM Triton Grouped GEMM BF16",
"SGLang Grouped GEMM BF16",
],
"styles": [("blue", "-"), ("green", "-")],
}
def run_benchmark(
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
config = get_benchmark_config(use_fp8_w8a8)
benchmark_config = triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[256, 512, 1024, 2048, 4096],
line_arg="provider",
line_vals=config["line_vals"],
line_names=config["line_names"],
styles=config["styles"],
ylabel="Bandwidth (GB/s)",
plot_name="grouped-gemm-performance",
args={},
)
@triton.testing.perf_report(benchmark_config)
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"Benchmarking {provider} with batch_size={batch_size}")
torch.cuda.manual_seed_all(0)
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
if provider == "fbgemm_triton_grouped_gemm_fp8":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="triton",
)
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
def run_func():
return fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
except Exception as e:
print(f"FP8 not supported, skipping: {e}")
return float("inf"), float("inf"), float("inf")
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="cutlass",
)
x, wq, w_scale, m_sizes = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
# Quantize input using triton_quantize_fp8_row
xq, x_scale = triton_quantize_fp8_row(x)
x_scale = x_scale.view(batch_size, -1)
def run_func():
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
)
except Exception as e:
print(
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
f"skipping: {e}"
)
return float("inf"), float("inf"), float("inf")
else:
test_data = create_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
# Calculate memory bandwidth for BF16 operations
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.bfloat16
)
if provider == "fbgemm_triton_grouped_gemm":
def run_func():
return fbgemm_grouped_gemm(
x, w_fbgemm, m_sizes, use_fast_accum=True
)
else:
def run_func():
return sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
for _ in range(10):
try:
run_func()
except Exception as e:
print(f"Error during warmup for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
torch.cuda.synchronize()
try:
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
# Convert time (ms) to bandwidth (GB/s)
# Bandwidth = Memory (bytes) / Time (seconds)
# Convert ms to seconds and bytes to GB (1e9)
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
# min bandwidth = max time, max bandwidth = min time
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
return gb_per_s, min_gb_per_s, max_gb_per_s
except Exception as e:
print(f"Error during benchmarking for {provider}: {e}")
return 0.0, 0.0, 0.0
dynamic_benchmark.run(
show_plots=True,
print_data=True,
save_path=save_path,
model_config=model_config,
use_fp8_w8a8=use_fp8_w8a8,
)
def verify_correctness(model_config):
print("Verifying correctness...")
batch_size = 128
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
print("✓ BF16 Correctness verification passed!")
else:
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
return False
return True
def main():
parser = argparse.ArgumentParser(
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
help="Model name to get configuration from",
)
parser.add_argument(
"--tp-size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument(
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
)
parser.add_argument(
"--save-path",
type=str,
default="./benchmark_grouped_gemm/",
help="Path to save benchmark results",
)
parser.add_argument(
"--verify-correctness",
action="store_true",
help="Verify correctness before benchmarking",
)
args = parser.parse_args()
try:
model_config = get_model_config(args.model, args.tp_size)
except Exception as e:
print(f"Failed to get model config: {e}")
print("Using default configuration...")
model_config = {
"num_groups": 8,
"hidden_size": 4096,
"intermediate_size": 14336,
"dtype": torch.bfloat16,
}
print("Running benchmark with:")
print(f" num_groups: {model_config['num_groups']}")
print(f" hidden_size: {model_config['hidden_size']}")
print(f" intermediate_size: {model_config['intermediate_size']}")
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
if args.verify_correctness:
if not verify_correctness(model_config):
print("Correctness verification failed. Exiting...")
return
try:
run_benchmark(
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
save_path=args.save_path,
)
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()
...@@ -246,7 +246,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s ...@@ -246,7 +246,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|-----------|-------------|----------| |-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 | | `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none | | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' | | `--moe-runner-backend` | Select the runner backend for MoE. | auto |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None | | `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
......
...@@ -13,22 +13,18 @@ from sgl_kernel import ( ...@@ -13,22 +13,18 @@ from sgl_kernel import (
from sglang.srt.layers.moe.ep_moe.kernels import ( from sglang.srt.layers.moe.ep_moe.kernels import (
post_reorder_triton_kernel_for_cutlass_moe, post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe, pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess, run_moe_ep_preproess,
) )
def cutlass_w4a8_moe( def cutlass_w4a8_moe(
start_expert_id: int,
end_expert_id: int,
total_num_experts: int,
a: torch.Tensor, a: torch.Tensor,
w1_q: torch.Tensor, w1_q: torch.Tensor,
w2_q: torch.Tensor, w2_q: torch.Tensor,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids_: torch.Tensor, topk_ids: torch.Tensor,
local_topk_ids: torch.Tensor,
a_strides1: torch.Tensor, a_strides1: torch.Tensor,
b_strides1: torch.Tensor, b_strides1: torch.Tensor,
c_strides1: torch.Tensor, c_strides1: torch.Tensor,
...@@ -64,6 +60,7 @@ def cutlass_w4a8_moe( ...@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4] Shape: [num_experts, N // 512, K * 4]
- topk_weights (torch.Tensor): The weights of each token->expert mapping. - topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm. - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm. - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
...@@ -83,7 +80,7 @@ def cutlass_w4a8_moe( ...@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
Returns: Returns:
- torch.Tensor: The fp8 output tensor after applying the MoE layer. - torch.Tensor: The fp8 output tensor after applying the MoE layer.
""" """
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.int8 assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8 assert w2_q.dtype == torch.int8
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1" assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
...@@ -96,20 +93,21 @@ def cutlass_w4a8_moe( ...@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0) num_local_experts = w1_q.size(0)
m = a.size(0) m = a.size(0)
k = w1_q.size(2) * 2 # w1_q is transposed and packed k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1) topk = topk_ids.size(1)
if apply_router_weight_on_input: if apply_router_weight_on_input:
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1" assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
device = a.device device = a.device
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
_, src2dst, _ = run_cutlass_moe_ep_preproess( _, src2dst, _ = run_moe_ep_preproess(
local_topk_ids, topk_ids,
num_experts, num_local_experts,
) )
gateup_input = torch.empty( gateup_input = torch.empty(
...@@ -122,9 +120,9 @@ def cutlass_w4a8_moe( ...@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
a, a,
gateup_input, gateup_input,
src2dst, src2dst,
local_topk_ids, topk_ids,
a1_scale, a1_scale,
total_num_experts, num_local_experts,
topk, topk,
k, k,
BLOCK_SIZE=512, BLOCK_SIZE=512,
...@@ -133,16 +131,16 @@ def cutlass_w4a8_moe( ...@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel, # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
# they are kept to allow for a quick switch of the permutation logic # they are kept to allow for a quick switch of the permutation logic
# from the current triton kernel implementation to the cutlass-based one if needed. # from the current triton kernel implementation to the cutlass-based one if needed.
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data( get_cutlass_w4a8_moe_mm_data(
local_topk_ids, topk_ids,
expert_offsets, expert_offsets,
problem_sizes1, problem_sizes1,
problem_sizes2, problem_sizes2,
a_map, a_map,
c_map, c_map,
num_experts, num_local_experts,
n, n,
k, k,
) )
...@@ -195,12 +193,11 @@ def cutlass_w4a8_moe( ...@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
c2, c2,
output, output,
src2dst, src2dst,
local_topk_ids, topk_ids,
topk_weights, topk_weights,
num_experts,
topk, topk,
num_local_experts,
k, k,
0,
BLOCK_SIZE=512, BLOCK_SIZE=512,
) )
return output return output
from __future__ import annotations from __future__ import annotations
import logging import logging
from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch import torch
import triton
import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import ( from sglang.srt.layers.moe import (
get_deepep_mode, get_deepep_mode,
get_moe_a2a_backend, get_moe_a2a_backend,
...@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import ( ...@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.kernels import ( from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather, ep_gather,
ep_scatter, ep_scatter,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
silu_and_mul_masked_post_quant_fwd, silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale, tma_align_input_scale,
) )
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8 import Fp8Config
...@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import ( ...@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
CUTEDSL_MOE_NVFP4_DISPATCH, CUTEDSL_MOE_NVFP4_DISPATCH,
ModelOptNvFp4FusedMoEMethod, ModelOptNvFp4FusedMoEMethod,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.offloader import get_offloader from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ( from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
ceil_div,
dispose_tensor,
get_bool_env_var,
get_int_env_var,
is_cuda,
is_hip,
is_npu,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import ( from sglang.srt.layers.moe.token_dispatcher import (
...@@ -72,29 +56,13 @@ if _use_aiter: ...@@ -72,29 +56,13 @@ if _use_aiter:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO(kaixih@nvidia): ideally we should merge this logic into class DeepEPMoE(FusedMoE):
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
class EPMoE(FusedMoE):
""" """
MoE Expert Parallel Impl MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
""" """
_has_printed = False
def __init__( def __init__(
self, self,
num_experts: int, num_experts: int,
...@@ -108,272 +76,29 @@ class EPMoE(FusedMoE): ...@@ -108,272 +76,29 @@ class EPMoE(FusedMoE):
prefix: str = "", prefix: str = "",
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
): ):
super().__init__( super().__init__(
num_experts=num_experts, num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size, hidden_size=hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id, layer_id=layer_id,
top_k=top_k, num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
activation=activation, activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
) )
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config): if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False) self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = True self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
else: else:
self.use_fp8_w8a8 = False self.use_fp8_w8a8 = False
self.use_block_quant = False self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, topk_output)
else:
return super().forward(hidden_states, topk_output)
def forward_deepgemm(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)
# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
dispose_tensor(hidden_states)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = gateup_input_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
(
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
gateup_input_scale
)
),
)
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
del gateup_input
del gateup_input_fp8
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
del down_input
del down_input_fp8
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
"""
_has_printed = False
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
self.deepep_mode = get_deepep_mode() self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file # TODO: move to the beginning of the file
...@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE): ...@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
N = self.w13_weight.size(1) N = self.w13_weight.size(1)
scale_block_size = 128 scale_block_size = 128
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
w13_weight_fp8 = ( w13_weight_fp8 = (
self.w13_weight, self.w13_weight,
( (
...@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): ...@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
return FlashInferFusedMoE return FlashInferFusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass(): if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
return EPMoE
return FusedMoE return FusedMoE
......
...@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module): ...@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank() self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0 assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size self.num_local_experts = num_experts // self.moe_ep_size
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
if self.moe_ep_size > 1: if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion # TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1 # Create a tensor of size num_experts filled with -1
......
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
@dataclass
class DeepGemmRunnerInput(RunnerInput):
hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor
masked_m: torch.Tensor
expected_m: int
use_masked_gemm: bool
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmRunnerOutput(RunnerOutput):
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmMoeQuantInfo(MoeQuantInfo):
w13_weight: torch.Tensor
w2_weight: torch.Tensor
use_fp8: bool
w13_scale: Optional[torch.Tensor] = None
w2_scale: Optional[torch.Tensor] = None
block_shape: Optional[List[int]] = None
class DeepGemmRunnerCore(MoeRunnerCore):
def __init__(self, config: MoeRunnerConfig):
super().__init__(config)
assert self.config.activation == "silu"
def run(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> DeepGemmRunnerOutput:
if runner_input.use_masked_gemm:
hidden_states = self._run_masked_gemm(
runner_input,
quant_info,
running_state,
)
else:
hidden_states = self._run_contiguous_gemm(
runner_input,
quant_info,
running_state,
)
return DeepGemmRunnerOutput(hidden_states=hidden_states)
def _run_masked_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale
masked_m = runner_input.masked_m
expected_m = runner_input.expected_m
w13_weight = quant_info.w13_weight
w2_weight = quant_info.w2_weight
w13_scale = quant_info.w13_scale
w2_scale = quant_info.w2_scale
hidden_states_device = running_state["hidden_states_device"]
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = hidden_states_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
else:
hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
hidden_states_scale
)
num_groups, m, k = hidden_states.shape
n = w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(hidden_states, hidden_states_scale),
(w13_weight, w13_scale),
gateup_output,
masked_m,
expected_m,
)
dispose_tensor(hidden_states)
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = w2_weight.shape[1]
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
down_input_scale
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(down_input, down_input_scale),
(w2_weight, w2_scale),
down_output,
masked_m,
expected_m,
)
del down_input
return down_output
def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
pass
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
dispatch_output: StandardDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
hidden_states, topk_output = dispatch_output
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
hidden_states_ref = hidden_states
topk_weights, topk_ids = topk_weights, topk_ids
# PreReorder
masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
runner_config.num_local_experts,
hidden_states,
runner_config.top_k,
quant_info.block_shape,
)
)
dispose_tensor(hidden_states_ref)
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
running_state["hidden_states_shape"] = hidden_states_shape
running_state["hidden_states_dtype"] = hidden_states_dtype
running_state["hidden_states_device"] = hidden_states_device
running_state["src2dst"] = src2dst
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
masked_m=masked_m,
expected_m=expected_m,
use_masked_gemm=True,
)
@register_post_permute("deep_gemm", "standard")
def post_permute_deep_gemm_to_standard(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
hidden_states_shape = running_state["hidden_states_shape"]
hidden_states_dtype = running_state["hidden_states_dtype"]
hidden_states_device = running_state["hidden_states_device"]
src2dst = running_state["src2dst"]
topk_ids = running_state["topk_ids"]
topk_weights = running_state["topk_weights"]
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
runner_output.hidden_states,
output,
src2dst,
topk_ids,
topk_weights,
runner_config.top_k,
hidden_states_shape[1],
BLOCK_SIZE=512,
)
dispose_tensor(runner_output.hidden_states)
if runner_config.routed_scaling_factor is not None:
output *= runner_config.routed_scaling_factor
return StandardCombineInput(
hidden_states=output,
)
...@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import ( ...@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
MoeRunnerConfig, MoeRunnerConfig,
PermuteMethodPool, PermuteMethodPool,
) )
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.utils import get_moe_a2a_backend from sglang.srt.layers.moe.utils import get_moe_a2a_backend
...@@ -30,6 +31,8 @@ class MoeRunner: ...@@ -30,6 +31,8 @@ class MoeRunner:
if runner_backend.is_triton(): if runner_backend.is_triton():
self.runner_core = TritonRunnerCore(config) self.runner_core = TritonRunnerCore(config)
elif runner_backend.is_deep_gemm():
self.runner_core = DeepGemmRunnerCore(config)
else: else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
......
...@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum): ...@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
class MoeRunnerBackend(Enum): class MoeRunnerBackend(Enum):
AUTO = "auto" AUTO = "auto"
DEEP_GEMM = "deep_gemm"
TRITON = "triton" TRITON = "triton"
TRITON_KERNEL = "triton_kernel" TRITON_KERNEL = "triton_kernel"
FLASHINFER_TRTLLM = "flashinfer_trtllm" FLASHINFER_TRTLLM = "flashinfer_trtllm"
...@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum): ...@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
def is_auto(self): def is_auto(self):
return self == MoeRunnerBackend.AUTO return self == MoeRunnerBackend.AUTO
def is_deep_gemm(self):
return self == MoeRunnerBackend.DEEP_GEMM
def is_triton(self): def is_triton(self):
return self == MoeRunnerBackend.TRITON return self == MoeRunnerBackend.TRITON
...@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend: ...@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
def get_moe_runner_backend() -> MoeRunnerBackend: def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None: if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend") logger.warning(
"MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
)
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
return MOE_RUNNER_BACKEND return MOE_RUNNER_BACKEND
......
...@@ -31,8 +31,8 @@ except ImportError: ...@@ -31,8 +31,8 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BlockQuantScaleParameter, BlockQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
...@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def create_moe_runner( def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
): ):
from sglang.srt.layers.moe.utils import (
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
self.moe_runner_config = moe_runner_config self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) moe_runner_backend = get_moe_runner_backend()
if moe_runner_backend.is_auto():
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and get_moe_a2a_backend().is_deepep()
):
moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
else:
moe_runner_backend = MoeRunnerBackend.TRITON
if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
else:
# TODO(cwan): refactor other backends
pass
def apply( def apply(
self, self,
...@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
return StandardCombineInput(hidden_states=output) return StandardCombineInput(hidden_states=output)
quant_info = TritonMoeQuantInfo( if self.runner.runner_backend.is_deep_gemm():
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight, w13_weight = layer.w13_weight
use_fp8_w8a8=True, w2_weight = layer.w2_weight
w13_scale=(
layer.w13_weight_scale_inv if self.block_quant:
if self.block_quant block_shape = self.quant_config.weight_block_size
else layer.w13_weight_scale w13_scale = layer.w13_weight_scale_inv
), w2_scale = layer.w2_weight_scale_inv
w2_scale=( else:
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
), scale_block_size = 128
a13_scale=layer.w13_input_scale, block_shape = [scale_block_size, scale_block_size]
a2_scale=layer.w2_input_scale, w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
block_shape=self.quant_config.weight_block_size, w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
) w13_scale = (
layer.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_scale_k, dim=2)
)
w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
w2_scale = (
layer.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_scale_k, dim=2)
)
quant_info = DeepGemmMoeQuantInfo(
w13_weight=w13_weight,
w2_weight=w2_weight,
use_fp8=True,
w13_scale=w13_scale,
w2_scale=w2_scale,
block_shape=block_shape,
)
elif self.runner.runner_backend.is_triton():
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
else:
raise NotImplementedError(
"Unsupported runner backend: %s" % self.runner.runner_backend
)
return self.runner.run(dispatch_output, quant_info) return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits( def apply_with_router_logits(
......
...@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs ...@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.token_dispatcher import ( from sglang.srt.layers.moe.token_dispatcher import (
CombineInput, CombineInput,
StandardDispatchOutput, StandardDispatchOutput,
...@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig): ...@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers): if is_layer_skipped(prefix, self.ignored_layers):
...@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ...@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def create_weights( def create_weights(
self, self,
layer: EPMoE, layer: Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
...@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ...@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: EPMoE, layer: Module,
dispatch_output: StandardDispatchOutput, dispatch_output: StandardDispatchOutput,
) -> CombineInput: ) -> CombineInput:
...@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ...@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
topk_output = dispatch_output.topk_output topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
if get_moe_expert_parallel_world_size() > 1:
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
output = cutlass_w4a8_moe( output = cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ...@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv, layer.w2_weight_scale_inv,
topk_weights, topk_weights,
topk_ids, topk_ids,
local_topk_ids,
self.a_strides1, self.a_strides1,
self.b_strides1, self.b_strides1,
self.c_strides1, self.c_strides1,
......
...@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import ( ...@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
...@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module): ...@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
) )
kwargs = {} self.experts = FusedMoE(
if get_moe_expert_parallel_world_size() > 1:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE
kwargs["reduce_results"] = reduce_results
kwargs["use_presharded_weights"] = use_presharded_weights
kwargs["inplace"] = inplace
kwargs["no_combine"] = no_combine
self.experts = MoEImpl(
num_experts=num_experts, num_experts=num_experts,
top_k=top_k, top_k=top_k,
layer_id=layer_id, layer_id=layer_id,
...@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module): ...@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
activation="gelu", activation="gelu",
**kwargs, reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
inplace=inplace,
no_combine=no_combine,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
...@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import ( ...@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
...@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module): ...@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
renormalize=True, renormalize=True,
) )
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE self.experts = FusedMoE(
self.experts = MoEImpl(
num_experts=num_experts, num_experts=num_experts,
top_k=top_k, top_k=top_k,
layer_id=layer_id, layer_id=layer_id,
......
@@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]

+MOE_RUNNER_BACKEND_CHOICES = [
+    "auto",
+    "deep_gemm",
+    "triton",
+    "triton_kernel",
+    "flashinfer_trtllm",
+    "flashinfer_cutlass",
+    "flashinfer_mxfp4",
+    "flashinfer_cutedsl",
+]
+
 # Allow external code to add more choices
 def add_load_format_choices(choices):
@@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices):
     GRAMMAR_BACKEND_CHOICES.extend(choices)

+def add_moe_runner_backend_choices(choices):
+    MOE_RUNNER_BACKEND_CHOICES.extend(choices)
+
 def add_deterministic_attention_backend_choices(choices):
     DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
@@ -315,14 +330,7 @@ class ServerArgs:
     # Expert parallelism
     ep_size: int = 1
     moe_a2a_backend: Literal["none", "deepep"] = "none"
-    moe_runner_backend: Literal[
-        "auto",
-        "triton",
-        "triton_kernel",
-        "flashinfer_trtllm",
-        "flashinfer_cutlass",
-        "flashinfer_mxfp4",
-    ] = "auto"
+    moe_runner_backend: str = "auto"
     flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
@@ -2191,15 +2199,7 @@ class ServerArgs:
         parser.add_argument(
             "--moe-runner-backend",
             type=str,
-            choices=[
-                "auto",
-                "triton",
-                "triton_kernel",
-                "flashinfer_trtllm",
-                "flashinfer_cutlass",
-                "flashinfer_mxfp4",
-                "flashinfer_cutedsl",
-            ],
+            choices=MOE_RUNNER_BACKEND_CHOICES,
             default=ServerArgs.moe_runner_backend,
             help="Choose the runner backend for MoE.",
         )
...
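Since the backend list is now a module-level constant with a registration hook, downstream integrations can extend it before the argument parser is built. A minimal sketch, assuming the constant and hook above live in `sglang.srt.server_args` and using a hypothetical backend name:

```python
# Sketch only: "my_custom_backend" is a hypothetical plugin name, and the
# import path is assumed from the hunks above. Registering the choice before
# ServerArgs builds its parser makes it a valid --moe-runner-backend value.
from sglang.srt.server_args import add_moe_runner_backend_choices

add_moe_runner_backend_choices(["my_custom_backend"])
```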
import itertools
import random
import unittest
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase


# For test
def ep_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    topk_config: TopKConfig,
    # ep config
    num_experts: int = 256,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    num_experts_per_partition: int = 128,
    start_expert_id: int = 0,
    end_expert_id: int = 127,
    use_fp8_w8a8: bool = False,
    w1_scale_inv: Optional[torch.Tensor] = None,
    w2_scale_inv: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
):
    use_blockwise_fp8 = block_shape is not None
    top_k = topk_config.top_k

    topk_output = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        topk_config=topk_config,
    )
    topk_weights, topk_ids, _ = topk_output

    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

    gateup_input = torch.empty(
        (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
        device=hidden_states.device,
        dtype=(
            fp8_dtype
            if (use_fp8_w8a8 and not use_blockwise_fp8)
            else hidden_states.dtype
        ),
    )

    if use_fp8_w8a8 and not use_blockwise_fp8:
        max_value = (
            torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
        )
        w1_input_scale = max_value / torch.finfo(fp8_dtype).max
    else:
        w1_input_scale = None

    # PreReorder
    pre_reorder_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        gateup_input,
        src2dst,
        topk_ids,
        w1_input_scale,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.shape[1],
        BLOCK_SIZE=512,
        use_per_token_if_dynamic=True,
    )

    seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
    weight_indices_cur_rank = torch.arange(
        0,
        num_experts_per_partition,
        device=hidden_states.device,
        dtype=torch.int64,
    )

    # GroupGemm-0
    gateup_output = torch.empty(
        gateup_input.shape[0],
        w1.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    gateup_output = grouped_gemm_triton(
        a=gateup_input,
        b=w1,
        c=gateup_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w1_input_scale,
        scale_b=w1_scale_inv,
        block_shape=block_shape,
    )

    # Act
    down_input = torch.empty(
        gateup_output.shape[0],
        gateup_output.shape[1] // 2,
        device=gateup_output.device,
        dtype=(
            fp8_dtype
            if (use_fp8_w8a8 and not use_blockwise_fp8)
            else hidden_states.dtype
        ),
    )
    if use_fp8_w8a8 and not use_blockwise_fp8:
        w2_input_scale = torch.ones(
            num_experts_per_partition,
            dtype=torch.float32,
            device=hidden_states.device,
        )
    else:
        w2_input_scale = None

    silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
        gateup_output,
        down_input,
        gateup_output.shape[1],
        reorder_topk_ids,
        w2_input_scale,
        start_expert_id,
        end_expert_id,
        BLOCK_SIZE=512,
    )

    # GroupGemm-1
    down_output = torch.empty(
        down_input.shape[0],
        w2.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    down_output = grouped_gemm_triton(
        a=down_input,
        b=w2,
        c=down_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w2_input_scale,
        scale_b=w2_scale_inv,
        block_shape=block_shape,
    )

    # PostReorder
    output = torch.empty_like(hidden_states)
    post_reorder_triton_kernel[(hidden_states.size(0),)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.size(1),
        0,
        BLOCK_SIZE=512,
    )
    return output
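
# Illustrative shape walk-through for the reference ep_moe above; the names and
# sizes below are examples for intuition only (M tokens, hidden size K, expert
# intermediate size N, top_k experts per token), not part of any kernel API.
M, K, N, top_k = 4, 256, 128, 2
gateup_input_shape = (M * top_k, K)       # each token replicated top_k times
gateup_output_shape = (M * top_k, 2 * N)  # gate and up projections concatenated (w1.shape[1])
down_input_shape = (M * top_k, N)         # silu_and_mul halves the last dimension
down_output_shape = (M * top_k, K)        # projected back to hidden size (w2.shape[1])
output_shape = (M, K)                     # post-reorder combines the top_k copies per token
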
# test util
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized tensor back to a full-precision tensor.

    The inputs are the block-wise quantized tensor `x_q_block`, the block-wise
    quantization scales `x_s`, and the block size. The output is the dequantized
    tensor. Note only float8 is supported for now.
    """
    # process 3D tensor
    if x_q_block.dim() == 3:
        batch_size = x_q_block.size(0)
        return torch.stack(
            [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
        )

    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    x_dq_block = x_q_block.to(torch.float32)

    x_dq_block_tiles = [
        [
            x_dq_block[
                j * block_n : min((j + 1) * block_n, n),
                i * block_k : min((i + 1) * block_k, k),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]

    for i in range(k_tiles):
        for j in range(n_tiles):
            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

    return x_dq_block
class TestW8A8BlockFP8EPMoE(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 222, 1024, 2048]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 16]
    ep_size = [2, 4]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_ep_moe(
        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
    ):
        torch.manual_seed(seed)
        random.seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
        w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
        w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)

        score = torch.randn((M, E), dtype=dtype)

        num_experts_per_partition = E // ep_size
        cur_rank = random.randint(0, ep_size - 1)
        start_id = cur_rank * num_experts_per_partition
        end_id = start_id + num_experts_per_partition - 1

        topk_config = TopKConfig(
            top_k=topk,
            renormalize=False,
        )

        with torch.inference_mode():
            out = ep_moe(
                hidden_states=a,
                w1=w1,
                w2=w2,
                router_logits=score,
                topk_config=topk_config,
                use_fp8_w8a8=True,
                w1_scale_inv=w1_s,
                w2_scale_inv=w2_s,
                block_shape=block_size,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )
            ref_out = ep_moe(
                hidden_states=a,
                w1=w1_ref,
                w2=w2_ref,
                router_logits=score,
                topk_config=topk_config,
                use_fp8_w8a8=False,
                w1_scale_inv=None,
                w2_scale_inv=None,
                block_shape=None,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )

        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
            < 0.06
        )

    def test_w8a8_block_fp8_ep_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.ep_size,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                ep_size=params[4],
                topk=params[5],
                block_size=params[6],
                dtype=params[7],
                seed=params[8],
            ):
                self._w8a8_block_fp8_ep_moe(*params)
        torch.cuda.empty_cache()


if __name__ == "__main__":
    unittest.main(verbosity=2)
@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
     )
     topk_weights, topk_ids, _ = topk_output
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
-    expert_map[local_e:] = E
+    expert_map[local_e:] = -1

     output = cutlass_moe(
         a,
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
         c_strides2,
         s_strides13,
         s_strides2,
-        0,
-        local_e - 1,
-        E,
+        local_e,
         a1_scale,
         a2_scale,
         expert_map,
@@ -178,7 +176,7 @@ def cutlass_moe(
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
+    topk_ids: torch.Tensor,
     a_strides1: torch.Tensor,
     b_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
@@ -187,40 +185,32 @@ def cutlass_moe(
     c_strides2: torch.Tensor,
     s_strides13: torch.Tensor,
     s_strides2: torch.Tensor,
-    start_expert_id: int,
-    end_expert_id: int,
-    E: int,
+    num_local_experts: int,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
 ):
-    local_topk_ids = topk_ids_
-    local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
+    topk_ids = expert_map[topk_ids]
     device = a.device
-    local_num_experts = end_expert_id - start_expert_id + 1
     expert_offsets = torch.empty(
-        (local_num_experts + 1), dtype=torch.int32, device=device
+        (num_local_experts + 1), dtype=torch.int32, device=device
     )
     problem_sizes1 = torch.empty(
-        (local_num_experts, 3), dtype=torch.int32, device=device
+        (num_local_experts, 3), dtype=torch.int32, device=device
     )
     problem_sizes2 = torch.empty(
-        (local_num_experts, 3), dtype=torch.int32, device=device
+        (num_local_experts, 3), dtype=torch.int32, device=device
    )

     return cutlass_w4a8_moe(
-        start_expert_id,
-        end_expert_id,
-        E,
         a,
         w1_q,
         w2_q,
         w1_scale,
         w2_scale,
         topk_weights,
-        topk_ids_,
-        local_topk_ids,
+        topk_ids,
         a_strides1,
         b_strides1,
         c_strides1,
...
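The new `cutlass_moe` signature above replaces the `(start_expert_id, end_expert_id, E)` triple with a single `num_local_experts` count plus an `expert_map` whose entries translate global expert ids into local ones, with `-1` marking experts owned by other ranks. A minimal sketch of that mapping, with illustrative sizes:

```python
import torch

# Illustrative: 8 global experts, this rank owns the first 4, matching the
# expert_map set up in the test above.
E, local_e = 8, 4
expert_map = torch.arange(E, dtype=torch.int32)  # identity for local experts
expert_map[local_e:] = -1                        # experts on other ranks map to -1

topk_ids = torch.tensor([[0, 5], [3, 7]])        # global expert ids chosen per token
local_topk_ids = expert_map[topk_ids]            # tensor([[0, -1], [3, -1]], dtype=torch.int32)
```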
@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
 )


-class TestEpMoE(CustomTestCase):
+class TestEp(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -34,18 +34,6 @@ class TestEpMoE(CustomTestCase):
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)

-    def test_mmlu(self):
-        args = SimpleNamespace(
-            base_url=self.base_url,
-            model=self.model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-        )
-
-        metrics = run_eval(args)
-        self.assertGreaterEqual(metrics["score"], 0.5)
-
     def test_mgsm_en(self):
         args = SimpleNamespace(
             base_url=self.base_url,
@@ -59,7 +47,7 @@ class TestEpMoE(CustomTestCase):
         self.assertGreaterEqual(metrics["score"], 0.8)


-class TestEpMoEFP8(CustomTestCase):
+class TestEpDeepGEMM(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -76,6 +64,8 @@ class TestEpMoEFP8(CustomTestCase):
                 "2",
                 "--quantization",
                 "fp8",
+                "--moe-runner-backend",
+                "deep_gemm",
             ],
         )
@@ -83,18 +73,6 @@ class TestEpMoEFP8(CustomTestCase):
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)

-    def test_mmlu(self):
-        args = SimpleNamespace(
-            base_url=self.base_url,
-            model=self.model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-        )
-
-        metrics = run_eval(args)
-        self.assertGreaterEqual(metrics["score"], 0.5)
-
     def test_mgsm_en(self):
         args = SimpleNamespace(
             base_url=self.base_url,
...
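For reference, the renamed `TestEpDeepGEMM` now pins the MoE runner explicitly instead of relying on the removed `EPMoE` path. A sketch of the tail of its server-launch argument list (the flags preceding `--quantization` are outside this hunk and omitted):

```python
# Sketch only: the arguments before "--quantization" are truncated in the hunk
# above, so only the tail exercised by this change is shown.
extra_args = [
    "--quantization", "fp8",
    "--moe-runner-backend", "deep_gemm",
]
```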
@@ -130,6 +130,7 @@ suites = {
         TestFile("test_modelopt_loader.py", 30),
     ],
     "per-commit-2-gpu": [
+        TestFile("ep/test_moe_ep.py", 140),
         TestFile("lora/test_lora_tp.py", 116),
         TestFile("rl/test_update_weights_from_distributed.py", 103),
         TestFile("test_data_parallelism.py", 73),
...