Unverified commit 253454de authored by Yuan Luo, committed by GitHub

Integrate triton moe kernel (#7689)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent ea3e7ffe
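The new path is opt-in: it is used only when the `triton_kernels` package is importable and the `--enable-triton-kernel-moe` server flag (added in the `ServerArgs` change below) is set. The main mechanical difference from the existing `fused_moe` path is the weight layout: the Triton grouped-GEMM kernels consume `w1` and `w2` with their last two dimensions transposed, which is why the benchmark and the weight loader below transpose before use. A minimal sketch of that conversion, assuming illustrative bf16 dummy shapes that are not taken from the commit:

```python
import torch

E, hidden, inter = 8, 1024, 2048  # illustrative sizes only

# Layout used by the existing sglang fused_moe path:
w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.bfloat16)  # fused gate_up_proj
w2 = torch.randn(E, hidden, inter, dtype=torch.bfloat16)      # down_proj

# Layout expected by the new triton_kernels path: swap the last two dims.
w1_tri = w1.transpose(-2, -1).contiguous()  # (E, hidden, 2 * inter)
w2_tri = w2.transpose(-2, -1).contiguous()  # (E, inter, hidden)
```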
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
import argparse

import torch
import triton
from transformers import AutoConfig

from sglang.srt.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_moe as fused_moe_sglang,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)


def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters"""
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    if config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen3MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = (
            config.n_routed_experts + 1
            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
            else config.n_routed_experts
        )
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": block_shape,
    }
    print(f"{shape_configs=}")
    return shape_configs


def fused_moe_triton_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
):
    return triton_kernel_moe_forward(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=False,
    )


def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
    block_shape=None,
):
    return fused_moe_sglang(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=False,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[128, 256, 512, 1024, 2048, 4096, 8192],
        line_arg="provider",
        line_vals=[
            "sglang_fused_moe_triton_v340",
            "sglang_fused_moe_triton",
        ],
        line_names=[
            "sglang_fused_moe_triton_v340",
            "sglang_fused_moe_triton",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(
    batch_size,
    provider,
    model_config,
    use_fp8_w8a8=False,
    use_cuda_graph: bool = False,
):
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]
    block_shape = model_config["block_shape"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
    w2 = torch.randn(
        num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
    )
    # The triton_kernels path expects the last two weight dims transposed.
    w1_tri = w1.clone()
    w2_tri = w2.clone()
    w1_tri = w1_tri.transpose(-2, -1).contiguous()
    w2_tri = w2_tri.transpose(-2, -1).contiguous()
    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    if provider == "sglang_fused_moe_triton_v340":
        api_func = fused_moe_triton_api
        api_kwargs = {
            "x": x,
            "w1": w1_tri,
            "w2": w2_tri,
            "input_gating": input_gating,
            "topk": topk,
        }
    else:
        api_func = fused_moe_sglang_api
        api_kwargs = {
            "x": x,
            "w1": w1,
            "w2": w2,
            "input_gating": input_gating,
            "topk": topk,
            "use_fp8_w8a8": use_fp8_w8a8,
            "block_shape": block_shape,
        }

    # Warmup
    for _ in range(10):
        _ = api_func(**api_kwargs)
    torch.cuda.synchronize()

    if use_cuda_graph:
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            api_func(**api_kwargs)
        torch.cuda.synchronize()
        bench_lambda = lambda: graph.replay()
    else:
        bench_lambda = lambda: api_func(**api_kwargs)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
    return ms, min_ms, max_ms


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/sglang_fused_moe/",
    )
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    try:
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl" if torch.cuda.is_available() else "gloo",
                init_method="tcp://127.0.0.1:23456",
                world_size=1,
                rank=0,
            )
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method="tcp://127.0.0.1:23456",
            local_rank=0,
            backend="nccl" if torch.cuda.is_available() else "gloo",
        )
        initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )
        model_config = get_model_config(args.model, args.tp_size)
        benchmark.run(
            show_plots=True,
            print_data=True,
            save_path=args.save_path,
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            use_cuda_graph=args.use_cuda_graph,
        )
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()


if __name__ == "__main__":
    main()
@@ -1737,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1822,6 +1823,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
...
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )
 
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
            requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
-
-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-            )
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
 
     def forward_cpu(
         self,
         layer: torch.nn.Module,
@@ -475,9 +506,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
 
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -597,6 +632,8 @@ class FusedMoE(torch.nn.Module):
            )
        else:
            if not self.use_presharded_weights:
+               if self.use_triton_kernels:
+                   loaded_weight = loaded_weight.transpose(-2, -1)
                loaded_weight = loaded_weight.narrow(
                    shard_dim, shard_size * tp_rank, shard_size
                )
@@ -630,6 +667,8 @@ class FusedMoE(torch.nn.Module):
            )
        else:
            if not self.use_presharded_weights:
+               if self.use_triton_kernels:
+                   loaded_weight = loaded_weight.transpose(-2, -1)
                loaded_weight = loaded_weight.narrow(
                    shard_dim, shard_size * tp_rank, shard_size
                )
@@ -716,6 +755,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
...
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
from typing import Optional

import torch
from sgl_kernel import gelu_and_mul, silu_and_mul
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing

from sglang.srt.utils import direct_register_custom_op


def triton_kernel_moe_forward(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    if not renormalize:
        gating_output = torch.softmax(gating_output, dim=-1)
    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)

    return triton_kernel_fused_experts(
        hidden_states,
        w1,
        w2,
        routing_data,
        gather_idx,
        scatter_idx,
        inplace=inplace,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )


# This is a triton implementation of the fused_experts function
def triton_kernel_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    routing_data: RoutingData,
    gather_indx: GatherIndx,
    scatter_indx: ScatterIndx,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
    assert per_channel_quant == False, "per_channel_quant is not supported"
    assert expert_map == None, "expert_map is not supported"
    assert w1_scale == None, "w1_scale is not supported"
    assert w2_scale == None, "w2_scale is not supported"
    assert a1_scale == None, "a1_scale is not supported"
    assert a2_scale == None, "a2_scale is not supported"
    assert block_shape == None, "block_shape is not supported"

    # type check
    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"

    # Shape check
    assert hidden_states.ndim == 2, "hidden_states must be 2D"
    assert (
        hidden_states.shape[-1] == w1.shape[-2]
    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
    assert (
        w2.shape[-1] == w1.shape[1]
    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"

    # feature check
    assert inplace == False, "Inplace is not supported in new triton MoE kernel"

    M, K = hidden_states.shape
    E, _, N = w1.shape
    n_expts_act = routing_data.n_expts_act
    dtype = hidden_states.dtype

    if global_num_experts == -1:
        global_num_experts = E

    # consistent with default implementation
    intermediate_cache2 = torch.empty(
        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
    )

    # First grouped GEMM: gather each token to its selected experts and apply
    # the fused gate_up projection (w1).
    intermediate_cache1 = matmul_ogs(
        hidden_states,
        w1,
        None,
        routing_data,
        gather_indx=gather_indx,
        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
    )

    if activation == "silu":
        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
    elif activation == "gelu":
        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    # Second grouped GEMM: apply the down projection (w2) and scatter the expert
    # outputs back per token, weighted by the routing scores unless those were
    # already applied on the input.
    intermediate_cache3 = matmul_ogs(
        intermediate_cache2,
        w2,
        None,
        routing_data,
        scatter_indx=scatter_indx,
        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
    )

    return intermediate_cache3


def triton_kernel_moe_forward_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="forward_cuda_triton",
    op_func=triton_kernel_moe_forward,
    mutates_args=[],
    fake_impl=triton_kernel_moe_forward_fake,
)
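For reference, a minimal call sketch of `triton_kernel_moe_forward` as exercised by the unit test further below; it assumes a CUDA device with the `triton_kernels` package installed, bf16 tensors, and the transposed weight layout enforced by the asserts above (shapes are illustrative only):

```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)

M, hidden, inter, E, topk = 32, 1024, 2048, 8, 2  # illustrative sizes only
x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
# w1 is (E, hidden, 2 * inter) and w2 is (E, inter, hidden), i.e. the
# transposed layout relative to sglang's fused_moe weights.
w1 = torch.randn(E, hidden, 2 * inter, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, inter, hidden, dtype=torch.bfloat16, device="cuda")
gating = torch.randn(M, E, dtype=torch.bfloat16, device="cuda")

out = triton_kernel_moe_forward(x, w1, w2, gating, topk, renormalize=False)
print(out.shape)  # expected (M, hidden)
```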
@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
+    "enable_triton_kernel_moe",
 ]
 
 # Put some global args for easy access
...
@@ -222,6 +222,7 @@ class ServerArgs:
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
+    enable_triton_kernel_moe: bool = False
 
     warmups: Optional[str] = None
 
     # Debug tensor dumps
@@ -1554,6 +1555,11 @@ class ServerArgs:
             action="store_true",
             help="Enable returning hidden states with responses.",
         )
+        parser.add_argument(
+            "--enable-triton-kernel-moe",
+            action="store_true",
+            help="Use triton moe grouped gemm kernel.",
+        )
         parser.add_argument(
             "--warmups",
             type=str,
...
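For orientation, an informal trace of how the new flag reaches the kernel dispatch; the names come from the diffs above, and the arrows are illustrative rather than executable code:

```python
# --enable-triton-kernel-moe
#   -> ServerArgs.enable_triton_kernel_moe
#   -> GLOBAL_SERVER_ARGS_KEYS / global_server_args_dict["enable_triton_kernel_moe"]
#   -> FusedMoE.use_triton_kernels (forced off on CPU)
#   -> UnquantizedFusedMoEMethod(use_triton_kernels=True)
#   -> triton_kernel_moe_forward(...) inside forward_cuda
```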
import unittest

import torch
import torch.nn.functional as F
from tqdm import tqdm

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)
from sglang.test.test_utils import CustomTestCase


class TestFusedMOE(CustomTestCase):
    NUM_EXPERTS = [8, 64]
    TOP_KS = [2, 4]

    @staticmethod
    def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
        """Create a random CUDA tensor

        Args:
            shape: Tensor shape
            dtype: Data type
            mean: Mean value
            std: Standard deviation

        Returns:
            torch.Tensor: Randomly initialized CUDA tensor
        """
        return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)

    def get_tolerance(self, dtype):
        """Get tolerance values for different data types

        Args:
            dtype: Data type

        Returns:
            tuple: (relative tolerance, absolute tolerance)
        """
        if dtype == torch.float32:
            return 1e-5, 1e-5
        elif dtype in [torch.float16, torch.bfloat16]:
            return 1e-5, 1e-5
        else:
            return 1e-2, 1e-2  # Default values for other types

    def torch_naive_moe(
        self,
        a,
        w1,
        w2,
        score,
        topk,
    ):
        B, D = a.shape
        a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)
        topk_weight = topk_weight.view(-1)
        topk_ids = topk_ids.view(-1)

        if w1.dtype == torch.float8_e4m3fn:
            w1_compute = w1.to(a.dtype)
            w2_compute = w2.to(a.dtype)
        else:
            w1_compute = w1
            w2_compute = w2

        for i in range(w1_compute.shape[0]):
            mask = topk_ids == i
            if mask.sum():
                out[mask] = SiluAndMul()(
                    a[mask] @ w1_compute[i].transpose(0, 1)
                ) @ w2_compute[i].transpose(0, 1)

        return (
            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
        ).sum(dim=1)

    def _test_case(self, m, n, k, e, topk, dtype):
        rtol, atol = self.get_tolerance(dtype)

        a = self.create_random_cuda_tensor((m, k), dtype)
        w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
        w2 = self.create_random_cuda_tensor((e, k, n), dtype)
        # The triton kernels consume weights with the last two dims transposed.
        w1_tri = w1.clone()
        w2_tri = w2.clone()
        w1_tri = w1_tri.transpose(-2, -1).contiguous()
        w2_tri = w2_tri.transpose(-2, -1).contiguous()
        score = self.create_random_cuda_tensor((m, e), dtype)

        triton_output = triton_kernel_moe_forward(
            a, w1_tri, w2_tri, score, topk, renormalize=False
        )
        torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
        torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)

    def test_various_configurations(self):
        m_values = [1, 32, 64, 256]
        n_values = [128, 1024]
        k_values = [128, 512, 1024]
        dtypes = [torch.bfloat16]

        # Calculate total number of tests
        total_tests = (
            len(m_values)
            * len(n_values)
            * len(k_values)
            * len(self.NUM_EXPERTS)
            * len(self.TOP_KS)
            * len(dtypes)
        )

        # Create progress bar
        with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
            for m in m_values:
                for n in n_values:
                    for k in k_values:
                        for e in self.NUM_EXPERTS:
                            for topk in self.TOP_KS:
                                for dtype in dtypes:
                                    with self.subTest(
                                        m=m,
                                        n=n,
                                        k=k,
                                        e=e,
                                        topk=topk,
                                        dtype=dtype,
                                    ):
                                        self._test_case(
                                            m,
                                            n,
                                            k,
                                            e,
                                            topk,
                                            dtype,
                                        )
                                    torch.cuda.empty_cache()
                                    pbar.update(1)


if __name__ == "__main__":
    unittest.main()