Unverified Commit abda2542 authored by Cheng Wan, committed by GitHub

Fix tuning_fused_moe_triton.py (#8175)

parent 8cddfa56
@@ -18,6 +18,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     get_default_config,
     get_moe_configs,
 )
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
@@ -115,10 +116,15 @@ def benchmark_config(
         w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
 
-    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
+    topk_output = select_experts(x, input_gating, topk, renormalize=True)
 
     def prepare(i: int):
-        input_gating.copy_(gating_output[i])
+        input_gating = gating_output[i]
+        new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
+        topk_output.topk_weights.copy_(new_topk_output.topk_weights)
+        topk_output.topk_ids.copy_(new_topk_output.topk_ids)
+        topk_output.router_logits.copy_(new_topk_output.router_logits)
 
     def run():
         from sglang.srt.layers.moe.fused_moe_triton import override_config
@@ -128,9 +134,7 @@ def benchmark_config(
             x,
             w1,
             w2,
-            input_gating,
-            topk,
-            renormalize=True,
+            topk_output,
             inplace=True,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
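In short, the tuning script used to hand the raw gating tensor plus `topk` and `renormalize` directly to the fused MoE kernel; after this fix it precomputes a `topk_output` with `select_experts`, refreshes it in `prepare()` by copying `topk_weights`, `topk_ids`, and `router_logits` from a freshly computed result, and passes that object to the kernel. As a rough illustration of the routing step that `select_experts` performs, here is a minimal plain-PyTorch sketch; the function `naive_select_experts` and its exact behavior are illustrative assumptions, not sglang's implementation:

```python
# Minimal sketch of top-k expert routing, assuming only PyTorch.
# NOTE: this is NOT sglang's select_experts; it only mirrors the idea of
# producing (topk_weights, topk_ids, router_logits) from gating logits.
import torch


def naive_select_experts(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    # Softmax over the expert dimension, then keep the top-k probabilities per token.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Rescale the kept weights so they sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids, router_logits


# Example: 8 tokens routed over 16 experts, top-2 routing.
logits = torch.randn(8, 16)
weights, ids, _ = naive_select_experts(logits, top_k=2)
print(weights.shape, ids.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

Precomputing the routing result once and copying new values into it in `prepare()` keeps the tensors passed to the benchmarked kernel at fixed addresses across iterations, which matches how the benchmark reuses `x`, `w1`, and `w2`.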