Unverified Commit d2b8c412 authored by Yongfei Xu, committed by GitHub

Opt fused triton moe: add tma for down proj kernel (#10567)


Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
parent bf8f7a94
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import os
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
import ray
import torch
import triton
import triton.language as tl
from ray.experimental.tqdm_ray import tqdm
from sgl_kernel import silu_and_mul
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
get_config_dtype_str,
invoke_fused_moe_kernel,
moe_align_block_size,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
get_config_file_name,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
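# Returns four average latencies in microseconds for one tuning config:
# (kernel0 without TMA, kernel0 with TMA, kernel1 without TMA, kernel1 with TMA),
# where kernel0 is the gate/up projection and kernel1 is the down projection.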
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
topk_ids_dir: str,
block_shape: List[int] = None,
num_iters: int = 100,
) -> Tuple[float, float, float, float]:
ncu_enable = os.getenv("NCU_ENABLE", "0") == "1"
if ncu_enable:
num_iters = 1
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16 or use_int8_w8a8:
w1 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8 or use_int8_w8a8:
if use_int8_w8a8 and block_shape is None:
w1_scale = torch.randn(
num_experts, shard_intermediate_size, dtype=torch.float32
)
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
elif block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
)
topk_output = select_experts(hidden_states, input_gating, topk_config)
def prepare(i: int):
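# Replay routing decisions captured from a real run: topk_ids are loaded from
# dumps named topk_ids_layer{L}_idx{I}.pt under topk_ids_dir (L cycles over
# layers 3..60), while topk weights and router logits come from random gating.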
input_gating = gating_output[i]
topk_ids = torch.load(f"{topk_ids_dir}/topk_ids_layer{i%58+3}_idx{i//58}.pt")
new_topk_output = select_experts(hidden_states, input_gating, topk_config)
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
tokens, _topk = topk_output.topk_ids.shape
topk_output.topk_ids.copy_(topk_ids[:tokens, :_topk])
topk_output.router_logits.copy_(new_topk_output.router_logits)
moe_use_tma = False
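# run() reads moe_use_tma from the enclosing scope, so flipping the flag
# between calls benchmarks the same config with and without TMA.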
def run():
moe_runner_config = MoeRunnerConfig(
inplace=True,
)
topk_weights, topk_ids, _ = topk_output
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], num_experts
)
M = hidden_states.shape[0]
E, N, _ = w1.shape
topk = topk_ids.shape[1]
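# When the down-proj kernel uses TMA and writes sorted outputs, the caches
# must also cover the block padding from moe_align_block_size: up to
# (BLOCK_SIZE_M - 1) extra slots for at most min(M * topk, E + 1) segments.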
padded_tokens = (
min(M * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1) if moe_use_tma else 0
)
total_tokens = M * topk + padded_tokens
cache = torch.empty(
total_tokens * max(N, w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = cache[: total_tokens * N].view(
(total_tokens, N),
)
intermediate_cache2 = torch.empty(
(total_tokens, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(
(M, topk, w2.shape[1]),
)
compute_type = (
tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
)
apply_router_weight_on_input = moe_runner_config.apply_router_weight_on_input
with override_config(config):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(10 if not ncu_enable else 1):
invoke_fused_moe_kernel(
hidden_states,
w1,
None,
intermediate_cache1,
None,
w1_scale,
None,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=block_shape,
b_use_tma=moe_use_tma,
c_sorted=moe_use_tma,
filter_expert=False,
)
end_event.record()
end_event.synchronize()
time_cost0 = start_event.elapsed_time(end_event)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
for _ in range(10 if not ncu_enable else 1):
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
None,
intermediate_cache3,
a2_scale,
w2_scale,
None,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=block_shape,
a_use_tma=moe_use_tma,
b_use_tma=moe_use_tma,
filter_expert=False,
)
end_event.record()
end_event.synchronize()
time_cost1 = start_event.elapsed_time(end_event)
return time_cost0, time_cost1
# JIT compilation & warmup
if not ncu_enable:
moe_use_tma = False
run()
moe_use_tma = True
run()
latencies: List[float] = []
latencies1: List[float] = []
latencies_tma: List[float] = []
latencies1_tma: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
moe_use_tma = False
t0, t1 = run()
torch.cuda.synchronize()
latencies.append(t0)
latencies1.append(t1)
moe_use_tma = True
t0, t1 = run()
torch.cuda.synchronize()
latencies_tma.append(t0)
latencies1_tma.append(t1)
avg = sum(latencies) / (num_iters * 10) * 1000 # us
avg_tma = sum(latencies_tma) / (num_iters * 10) * 1000 # us
avg1 = sum(latencies1) / (num_iters * 10) * 1000 # us
avg1_tma = sum(latencies1_tma) / (num_iters * 10) * 1000 # us
return avg, avg_tma, avg1, avg1_tma
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_stages in [2]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_stages in [2, 3, 4, 5]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
class BestConfigTrace:
def __init__(self, name):
self.name = name
self.config = None
self.time_cost = float("inf")
self.time_cost_all = None # kernel0 without tma, kernel0 with tma, kernel1 without tma, kernel1 with tma
def update(self, config, time_cost, time_cost_all):
if time_cost < self.time_cost:
print(
f"New best config for {self.name}: {config}, {time_cost=}, {time_cost_all=}, org: {self.config}, {self.time_cost_all}",
flush=True,
)
self.config = config
self.time_cost = time_cost
self.time_cost_all = time_cost_all
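# total_time pairs the non-TMA kernel0 timing with the faster of the two
# kernel1 variants; config_dict(down_moe=True) additionally records whether
# the TMA down-proj beat the non-TMA one as USE_TMA.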
@property
def total_time(self):
return self.time_cost_all[0] + min(self.time_cost_all[2], self.time_cost_all[3])
def config_dict(self, down_moe=False):
if not down_moe:
return self.config
else:
return {
**self.config,
"USE_TMA": self.time_cost_all[2] > self.time_cost_all[3],
}
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU.
self.device_id = 0 # int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
cfg: Dict[str, int],
topk_ids_dir: str,
) -> Tuple[Dict[str, int], Tuple[float, float, float, float]]:
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
kernel_time = benchmark_config(
cfg,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
topk_ids_dir,
block_shape,
)
return cfg, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
topk_ids_dir: str,
) -> Tuple[Dict[str, int], Dict[str, int], Tuple[float, ...], Tuple[float, ...]]:
trace0 = BestConfigTrace("kernel0")
trace1 = BestConfigTrace("kernel1")
trace2 = BestConfigTrace("kernel all")
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
for config in tqdm(search_space):
try:
kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
topk_ids_dir,
block_shape,
num_iters=10,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
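# kernel0 (gate/up proj) is always scored without TMA; kernel1 (down proj)
# takes the better of its TMA and non-TMA timings.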
kt0 = kt0_no_tma
kt1 = min(kt1_no_tma, kt1_tma)
trace0.update(
config,
kt0,
(kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),
)
trace1.update(
config,
kt1,
(kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),
)
trace2.update(
config,
kt0 + kt1,
(kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),
)
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert trace0.config is not None
assert trace1.config is not None
print(
f"{num_tokens=}, {trace0.config=}, {trace0.time_cost_all=}, {trace1.config=}, {trace1.time_cost_all=}"
)
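# Both kernels share one moe_align_block_size result, so they must agree on
# BLOCK_SIZE_M. If the per-kernel winners disagree, fall back to the single
# config with the best combined time and use it for both kernels.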
if trace0.config["BLOCK_SIZE_M"] != trace1.config["BLOCK_SIZE_M"]:
best_trace = trace0 if trace0.total_time < trace1.total_time else trace1
best_trace = (
best_trace if best_trace.total_time < trace2.total_time else trace2
)
return (
best_trace.config_dict(),
best_trace.config_dict(True),
best_trace.time_cost_all,
best_trace.time_cost_all,
)
return (
trace0.config_dict(),
trace1.config_dict(True),
trace0.time_cost_all,
trace1.time_cost_all,
)
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
**({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
down_moe: bool = False,
) -> None:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
down_moe=down_moe,
)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = (
config.num_experts_per_tok
+ (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.num_experts_per_tok
)
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
topk_ids_dir = args.topk_ids_dir
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
8192,
]
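# Reverse so the largest batch sizes are dispatched first; the order is
# restored before the configs are saved.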
batch_sizes.reverse()
else:
batch_sizes = [args.batch_size]
if len(batch_sizes) == 1:
worker = BenchmarkWorker(args.seed)
if args.tune:
search_space = get_configs_compute_bound()
worker.tune(
batch_sizes[0],
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
search_space,
topk_ids_dir,
)
else:
cfg = {
"BLOCK_SIZE_M": args.configs[0],
"BLOCK_SIZE_N": args.configs[1],
"BLOCK_SIZE_K": args.configs[2],
"GROUP_SIZE_M": args.configs[3],
"num_warps": args.configs[4],
"num_stages": args.configs[5],
}
_, (t0, t0_tma, t1, t1_tma) = worker.benchmark(
args.batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
cfg,
topk_ids_dir,
)
print(f"{t0=}, {t0_tma=}, {t1=}, {t1_tma=}")
return
assert args.tune
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [
ray.remote(num_gpus=1)(BenchmarkWorker).remote(args.seed)
for _ in range(num_gpus)
]
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
search_space = get_configs_compute_bound()
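# For block-quantized weights, keep only configs whose BLOCK_SIZE_K evenly
# divides the weight-quantization block_k.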
if block_shape is not None:
block_n, block_k = block_shape[0], block_shape[1]
search_space = [
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
start = time.perf_counter()
configs = _distribute(
"tune",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
search_space,
topk_ids_dir,
)
for batch_size in batch_sizes
],
)
print(f"{configs=}", flush=True)
cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
with open(f"tuning_result_{cur_time}.txt", "w") as f:
print(configs, file=f)
batch_sizes.reverse()
configs0 = [config[0] for config in configs]
configs1 = [config[1] for config in configs]
configs0.reverse()
configs1.reverse()
best_configs0 = {M: sort_config(config) for M, config in zip(batch_sizes, configs0)}
save_configs(
best_configs0,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}
save_configs(
best_configs1,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
down_moe=True,
)
end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument(
"--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
default="auto",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
parser.add_argument("--configs", type=int, nargs="+", required=False)
parser.add_argument("--topk-ids-dir", type=str, required=True)
args = parser.parse_args()
main(args)
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5,
"USE_TMA": false
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"USE_TMA": false
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"USE_TMA": false
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5,
"USE_TMA": false
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4,
"USE_TMA": false
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
}
}
...@@ -23,7 +23,11 @@ from sglang.srt.utils import ( ...@@ -23,7 +23,11 @@ from sglang.srt.utils import (
) )
from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton from .fused_moe_triton_kernels import (
invoke_fused_moe_kernel,
moe_sum_reduce_triton,
support_tensor_descriptor,
)
from .moe_align_block_size import moe_align_block_size from .moe_align_block_size import moe_align_block_size
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -78,6 +82,7 @@ def inplace_fused_experts( ...@@ -78,6 +82,7 @@ def inplace_fused_experts(
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> None: ) -> None:
fused_experts_impl( fused_experts_impl(
hidden_states, hidden_states,
...@@ -106,6 +111,7 @@ def inplace_fused_experts( ...@@ -106,6 +111,7 @@ def inplace_fused_experts(
routed_scaling_factor, routed_scaling_factor,
gemm1_alpha, gemm1_alpha,
gemm1_limit, gemm1_limit,
filter_expert,
) )
...@@ -134,6 +140,7 @@ def inplace_fused_experts_fake( ...@@ -134,6 +140,7 @@ def inplace_fused_experts_fake(
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> None: ) -> None:
pass pass
...@@ -172,6 +179,7 @@ def outplace_fused_experts( ...@@ -172,6 +179,7 @@ def outplace_fused_experts(
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
return fused_experts_impl( return fused_experts_impl(
hidden_states, hidden_states,
...@@ -200,6 +208,7 @@ def outplace_fused_experts( ...@@ -200,6 +208,7 @@ def outplace_fused_experts(
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha, gemm1_alpha=gemm1_alpha,
gemm1_limit=gemm1_limit, gemm1_limit=gemm1_limit,
filter_expert=filter_expert,
) )
...@@ -229,6 +238,7 @@ def outplace_fused_experts_fake( ...@@ -229,6 +238,7 @@ def outplace_fused_experts_fake(
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -263,6 +273,10 @@ def fused_experts( ...@@ -263,6 +273,10 @@ def fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
): ):
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
filter_expert = (
moe_runner_config.num_experts is None
or moe_runner_config.num_experts != moe_runner_config.num_local_experts
)
if moe_runner_config.inplace: if moe_runner_config.inplace:
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense" assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts( torch.ops.sglang.inplace_fused_experts(
...@@ -290,6 +304,7 @@ def fused_experts( ...@@ -290,6 +304,7 @@ def fused_experts(
moe_runner_config.routed_scaling_factor, moe_runner_config.routed_scaling_factor,
moe_runner_config.gemm1_alpha, moe_runner_config.gemm1_alpha,
moe_runner_config.gemm1_clamp_limit, moe_runner_config.gemm1_clamp_limit,
filter_expert,
) )
return hidden_states return hidden_states
else: else:
...@@ -319,6 +334,7 @@ def fused_experts( ...@@ -319,6 +334,7 @@ def fused_experts(
routed_scaling_factor=moe_runner_config.routed_scaling_factor, routed_scaling_factor=moe_runner_config.routed_scaling_factor,
gemm1_alpha=moe_runner_config.gemm1_alpha, gemm1_alpha=moe_runner_config.gemm1_alpha,
gemm1_limit=moe_runner_config.gemm1_clamp_limit, gemm1_limit=moe_runner_config.gemm1_clamp_limit,
filter_expert=filter_expert,
) )
...@@ -336,6 +352,11 @@ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit): ...@@ -336,6 +352,11 @@ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1) return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
@functools.lru_cache()
def _down_moe_use_tma():
return support_tensor_descriptor()
def fused_experts_impl( def fused_experts_impl(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -363,6 +384,7 @@ def fused_experts_impl( ...@@ -363,6 +384,7 @@ def fused_experts_impl(
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
): ):
padded_size = padding_size padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
...@@ -402,25 +424,27 @@ def fused_experts_impl( ...@@ -402,25 +424,27 @@ def fused_experts_impl(
topk_ids.shape[1], topk_ids.shape[1],
config_dtype, config_dtype,
block_shape=block_shape, block_shape=block_shape,
return_down_config=True,
) )
config = get_config_func(M) config, (down_config, max_block_m) = get_config_func(M)
down_moe_use_tma = (
cache = torch.empty( _down_moe_use_tma()
M * topk_ids.shape[1] * max(N, w2.shape[1]), and down_config is not None
device=hidden_states.device, and down_config.pop("USE_TMA", False)
dtype=hidden_states.dtype,
) )
intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view( topk = topk_ids.shape[1]
(M, topk_ids.shape[1], N), max_padded_tokens = (
min(M * topk, E + 1) * (max_block_m - 1) if down_moe_use_tma else 0
) )
intermediate_cache2 = torch.empty( total_tokens = M * topk + max_padded_tokens
(M * topk_ids.shape[1], N // 2), cache = torch.empty(
total_tokens * max(N, w2.shape[1]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view( intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(
(M, topk_ids.shape[1], w2.shape[1]), (M, topk, w2.shape[1]),
) )
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
...@@ -428,7 +452,7 @@ def fused_experts_impl( ...@@ -428,7 +452,7 @@ def fused_experts_impl(
if no_combine: if no_combine:
assert not inplace assert not inplace
out_hidden_states = torch.empty( out_hidden_states = torch.empty(
(num_tokens, topk_ids.shape[1], w2.shape[1]), (num_tokens, topk, w2.shape[1]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
...@@ -453,12 +477,28 @@ def fused_experts_impl( ...@@ -453,12 +477,28 @@ def fused_experts_impl(
# chunk. Note that in most cases we only have one chunk # chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and # so the cache size and config are already set correctly and
# do not need to be adjusted. # do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] config, (down_config, _) = get_config_func(tokens_in_chunk)
intermediate_cache2 = intermediate_cache2[ down_moe_use_tma = (
: tokens_in_chunk * topk_ids.shape[1] _down_moe_use_tma()
] and down_config is not None
and down_config.pop("USE_TMA", False)
)
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
padded_tokens = (
min(tokens_in_chunk * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1)
if down_moe_use_tma
else 0
)
total_tokens = tokens_in_chunk * topk + padded_tokens
intermediate_cache1 = cache[: total_tokens * N].view(
(total_tokens, N),
)
intermediate_cache2 = torch.empty(
(total_tokens, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
...@@ -490,6 +530,8 @@ def fused_experts_impl( ...@@ -490,6 +530,8 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
c_sorted=down_moe_use_tma,
filter_expert=filter_expert,
) )
if activation == "silu": if activation == "silu":
if gemm1_alpha is not None: if gemm1_alpha is not None:
...@@ -536,7 +578,7 @@ def fused_experts_impl( ...@@ -536,7 +578,7 @@ def fused_experts_impl(
num_tokens_post_padded, num_tokens_post_padded,
not apply_router_weight_on_input, not apply_router_weight_on_input,
1, 1,
config, down_config or config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
...@@ -544,6 +586,9 @@ def fused_experts_impl( ...@@ -544,6 +586,9 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
a_use_tma=down_moe_use_tma,
b_use_tma=down_moe_use_tma,
filter_expert=filter_expert,
) )
if routed_scaling_factor is None: if routed_scaling_factor is None:
......
...@@ -21,6 +21,7 @@ def get_config_file_name( ...@@ -21,6 +21,7 @@ def get_config_file_name(
dtype: Optional[str], dtype: Optional[str],
block_shape: Optional[int] = None, block_shape: Optional[int] = None,
per_channel_quant: bool = False, per_channel_quant: bool = False,
down_moe: bool = False,
) -> str: ) -> str:
device_name = get_device_name().replace(" ", "_") device_name = get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}" dtype_selector = "" if not dtype else f",dtype={dtype}"
...@@ -28,7 +29,8 @@ def get_config_file_name( ...@@ -28,7 +29,8 @@ def get_config_file_name(
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
) )
per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else "" per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json" down_moe_selector = "_down" if down_moe else ""
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}{down_moe_selector}.json"
@functools.lru_cache @functools.lru_cache
...@@ -39,6 +41,7 @@ def get_moe_configs( ...@@ -39,6 +41,7 @@ def get_moe_configs(
block_n: Optional[int] = 0, block_n: Optional[int] = 0,
block_k: Optional[int] = 0, block_k: Optional[int] = 0,
per_channel_quant: bool = False, per_channel_quant: bool = False,
down_moe: bool = False,
) -> Optional[Dict[int, Any]]: ) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -54,7 +57,12 @@ def get_moe_configs( ...@@ -54,7 +57,12 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
json_file_name = get_config_file_name( json_file_name = get_config_file_name(
E, N, dtype, [block_n, block_k], per_channel_quant E,
N,
dtype,
[block_n, block_k],
per_channel_quant,
down_moe=down_moe,
) )
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains, # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
...@@ -177,9 +185,12 @@ def try_get_optimal_moe_config( ...@@ -177,9 +185,12 @@ def try_get_optimal_moe_config(
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
return_down_config: bool = False,
): ):
from sglang.srt.layers.moe.fused_moe_triton import get_config from sglang.srt.layers.moe.fused_moe_triton import get_config
down_config = None
max_block_m = None
override_config = get_config() override_config = get_config()
if override_config: if override_config:
config = override_config config = override_config
...@@ -188,7 +199,7 @@ def try_get_optimal_moe_config( ...@@ -188,7 +199,7 @@ def try_get_optimal_moe_config(
E, _, N = w2_shape E, _, N = w2_shape
block_n = block_shape[0] if block_shape else 0 block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0 block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k) configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=False)
if configs: if configs:
# If an optimal configuration map has been found, look up the # If an optimal configuration map has been found, look up the
...@@ -199,6 +210,21 @@ def try_get_optimal_moe_config( ...@@ -199,6 +210,21 @@ def try_get_optimal_moe_config(
config = get_default_config( config = get_default_config(
M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
) )
if return_down_config:
down_configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=True)
if down_configs:
down_config = down_configs[
min(down_configs.keys(), key=lambda x: abs(x - M))
]
down_config = dict(**down_config)
max_block_m = max(
[cfg["BLOCK_SIZE_M"] for cfg in down_configs.values()]
)
if return_down_config:
assert (
down_config is None or config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"]
)
return config, (down_config, max_block_m)
return config return config
......
...@@ -25,6 +25,13 @@ from sglang.srt.utils import ( ...@@ -25,6 +25,13 @@ from sglang.srt.utils import (
is_hip, is_hip,
) )
try:
from triton.tools.tensor_descriptor import TensorDescriptor
_support_tensor_descriptor = True
except:
_support_tensor_descriptor = False
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support() _is_cpu_amx_available = cpu_has_amx_support()
...@@ -41,6 +48,10 @@ elif _is_hip: ...@@ -41,6 +48,10 @@ elif _is_hip:
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
def support_tensor_descriptor():
return _support_tensor_descriptor
@triton.jit @triton.jit
def write_zeros_to_output( def write_zeros_to_output(
c_ptr, c_ptr,
...@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq( ...@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr, use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr, even_Ks: tl.constexpr,
filter_expert: tl.constexpr,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq( ...@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1: if filter_expert and off_experts == -1:
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back zeros to the output when the expert is not # Write back zeros to the output when the expert is not
# in the current expert parallel rank. # in the current expert parallel rank.
...@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq( ...@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
def fused_moe_kernel( def fused_moe_kernel(
# Pointers to matrices # Pointers to matrices
a_ptr, a_ptr,
a_desc,
b_ptr, b_ptr,
b_desc,
bias_ptr, bias_ptr,
c_ptr, c_ptr,
a_scale_ptr, a_scale_ptr,
...@@ -344,6 +358,8 @@ def fused_moe_kernel( ...@@ -344,6 +358,8 @@ def fused_moe_kernel(
use_int8_w8a16: tl.constexpr, use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr, per_channel_quant: tl.constexpr,
even_Ks: tl.constexpr, even_Ks: tl.constexpr,
c_sorted: tl.constexpr,
filter_expert: tl.constexpr,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -399,9 +415,10 @@ def fused_moe_kernel( ...@@ -399,9 +415,10 @@ def fused_moe_kernel(
offs_token = offs_token.to(tl.int64) offs_token = offs_token.to(tl.int64)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
off_experts = off_experts_i32.to(tl.int64)
if off_experts == -1: if filter_expert and off_experts == -1:
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back zeros to the output when the expert is not # Write back zeros to the output when the expert is not
# in the current expert parallel rank. # in the current expert parallel rank.
...@@ -421,15 +438,23 @@ def fused_moe_kernel( ...@@ -421,15 +438,23 @@ def fused_moe_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + ( if a_desc is not None:
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak assert use_fp8_w8a8 and group_n > 0 and group_k > 0
) start_offs_m = pid_m * BLOCK_SIZE_M
else:
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if b_desc is not None:
start_offs_n = pid_n * BLOCK_SIZE_N
else:
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if bias_ptr is not None: if bias_ptr is not None:
bias = tl.load( bias = tl.load(
bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
...@@ -443,8 +468,14 @@ def fused_moe_kernel( ...@@ -443,8 +468,14 @@ def fused_moe_kernel(
if use_fp8_w8a8 or use_int8_w8a8: if use_fp8_w8a8 or use_int8_w8a8:
# block-wise # block-wise
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm if a_desc is not None:
offs_bsn = offs_bn // group_n a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
else:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
if BLOCK_SIZE_N > group_n:
offs_bsn = offs_bn // group_n
else:
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
b_scale_ptrs = ( b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
) )
...@@ -469,37 +500,49 @@ def fused_moe_kernel( ...@@ -469,37 +500,49 @@ def fused_moe_kernel(
# `accumulator` will be converted back to fp16 after the loop. # `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): for k_start in range(0, K, BLOCK_SIZE_K):
# Load the next block of A and B, generate a mask by checking the # Load the next block of A and B, generate a mask by checking the
# K dimension. # K dimension.
if even_Ks: if a_desc is not None:
a = a_desc.load([start_offs_m, k_start])
elif even_Ks:
a = tl.load( a = tl.load(
a_ptrs, a_ptrs,
mask=token_mask[:, None], mask=token_mask[:, None],
other=0.0, other=0.0,
) )
b = tl.load(b_ptrs)
else: else:
a = tl.load( a = tl.load(
a_ptrs, a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
other=0.0, other=0.0,
) )
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
if b_desc is not None:
b = (
b_desc.load([off_experts_i32, start_offs_n, k_start])
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
.T
)
elif even_Ks:
b = tl.load(b_ptrs)
else:
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_int8_w8a16: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k offs_ks = k_start // group_k
a_scale = tl.load( a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
) )
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
if BLOCK_SIZE_N > group_n:
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
else: else:
if use_fp8_w8a8: if use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator) accumulator = tl.dot(a, b, acc=accumulator)
...@@ -508,8 +551,10 @@ def fused_moe_kernel( ...@@ -508,8 +551,10 @@ def fused_moe_kernel(
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak if a_desc is None:
b_ptrs += BLOCK_SIZE_K * stride_bk a_ptrs += BLOCK_SIZE_K * stride_ak
if b_desc is None:
b_ptrs += BLOCK_SIZE_K * stride_bk
if use_int8_w8a16: if use_int8_w8a16:
accumulator *= b_scale accumulator *= b_scale
...@@ -528,7 +573,12 @@ def fused_moe_kernel( ...@@ -528,7 +573,12 @@ def fused_moe_kernel(
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] if c_sorted:
c_ptrs = (
c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
)
else:
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
...@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel( ...@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
no_combine: bool = False, no_combine: bool = False,
a_use_tma: bool = False,
b_use_tma: bool = False,
c_sorted: bool = False,
filter_expert: bool = True,
) -> None: ) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel( ...@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
even_Ks=even_Ks, even_Ks=even_Ks,
filter_expert=filter_expert,
**config, **config,
) )
else: else:
if a_use_tma or b_use_tma:
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
if a_use_tma:
a_desc = TensorDescriptor(
A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
)
else:
a_desc = None
if b_use_tma:
b_desc = TensorDescriptor(
B,
B.shape,
B.stride(),
[1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
)
else:
b_desc = None
fused_moe_kernel[grid]( fused_moe_kernel[grid](
A, A,
a_desc,
B, B,
b_desc,
bias, bias,
C, C,
A_scale, A_scale,
...@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel( ...@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
B.stride(1), B.stride(1),
bias.stride(0) if bias is not None else 0, bias.stride(0) if bias is not None else 0,
bias.stride(1) if bias is not None else 0, bias.stride(1) if bias is not None else 0,
C.stride(1), C.stride(-2),
C.stride(2), C.stride(-1),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
...@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel( ...@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
even_Ks=even_Ks, even_Ks=even_Ks,
c_sorted=c_sorted,
filter_expert=filter_expert,
**config, **config,
) )
......
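The diff above routes the down-projection GEMM through Triton tensor descriptors (TMA): the host registers a scratch allocator, builds TensorDescriptor objects for A and B, and the kernel loads tiles with desc.load(...) instead of raw pointer arithmetic, gated by support_tensor_descriptor(). The sketch below is not part of this PR; it only illustrates that mechanism, assuming a Triton build that provides triton.tools.tensor_descriptor.TensorDescriptor and triton.set_allocator (the same imports checked above). It copies a matrix tile-by-tile through a descriptor so the host-side setup and the in-kernel load can be seen in isolation.

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def tma_copy_kernel(a_desc, out_ptr, K, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    # TMA load of one (BLOCK_M, BLOCK_K) tile through the host-built
    # descriptor, mirroring a_desc.load([start_offs_m, k_start]) above.
    tile = a_desc.load([pid_m * BLOCK_M, pid_k * BLOCK_K])
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    # Plain (non-TMA) store of the same tile into a row-major output.
    tl.store(out_ptr + offs_m[:, None] * K + offs_k[None, :], tile)


def demo_tma_copy(M: int = 256, K: int = 512, BLOCK_M: int = 64, BLOCK_K: int = 128):
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    out = torch.empty_like(a)

    # TMA descriptors need a global-memory workspace, registered the same way
    # invoke_fused_moe_kernel does above.
    def alloc_fn(size: int, alignment: int, stream):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)
    # Descriptor over the full tensor with a fixed (BLOCK_M, BLOCK_K) box.
    a_desc = TensorDescriptor(a, a.shape, a.stride(), [BLOCK_M, BLOCK_K])
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
    tma_copy_kernel[grid](a_desc, out, K, BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K)
    assert torch.equal(a, out)

Whether the descriptor path actually maps to hardware TMA depends on the GPU generation; the import check above only guarantees the API is present.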