Unverified Commit f8b757bc authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

fix: resolve tuning fused moe issue (#9587)

parent ebd9dbe7
...@@ -22,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( ...@@ -22,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
) )
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip, is_rocm from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
...@@ -287,7 +287,7 @@ class BenchmarkWorker: ...@@ -287,7 +287,7 @@ class BenchmarkWorker:
) )
else: else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
with torch.cuda.device(self.device_id) if is_rocm() else nullcontext(): with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
kernel_time = benchmark_config( kernel_time = benchmark_config(
config, config,
num_tokens, num_tokens,
...@@ -319,7 +319,7 @@ class BenchmarkWorker: ...@@ -319,7 +319,7 @@ class BenchmarkWorker:
) -> Dict[str, int]: ) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
with torch.cuda.device(self.device_id) if is_rocm() else nullcontext(): with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
for config in tqdm(search_space): for config in tqdm(search_space):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment