Unverified commit a3b810eb authored by mpashkovskiy, committed by GitHub

fix: enable multi-GPU Triton fused MoE tuning (#6295)

parent 94959237
@@ -3,6 +3,7 @@ import argparse
 import json
 import time
 from datetime import datetime
+from contextlib import nullcontext
 from typing import Any, Dict, List, Tuple, TypedDict
 import ray
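The new `nullcontext` import is what lets a single `with` statement become a no-op when the ROCm-specific device pinning is not needed. A minimal sketch of that conditional-context-manager pattern, using an illustrative `announce` context manager rather than anything from the benchmark script:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def announce(name):
    # Stand-in for a real context manager such as torch.cuda.device(...).
    print(f"enter {name}")
    yield
    print(f"exit {name}")

def run(pin: bool) -> None:
    # The body runs either way; the context is only entered when `pin` is True.
    with announce("gpu") if pin else nullcontext():
        print("work")

run(True)   # enter gpu / work / exit gpu
run(False)  # work
```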
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, is_rocm

 _is_hip = is_hip()
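`is_rocm` joins the existing `is_hip` import so the device-pinning path can be limited to ROCm builds. As a rough, hypothetical illustration (not necessarily sglang's exact implementation), such a check usually keys off PyTorch's HIP build metadata:

```python
import torch

def rocm_build() -> bool:
    # Hypothetical stand-in for sglang.srt.utils.is_rocm(): PyTorch wheels
    # built against ROCm expose a HIP version string; CUDA builds report None.
    return torch.version.hip is not None
```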
@@ -245,6 +246,9 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         torch.cuda.manual_seed_all(0)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU.
+        self.device_id = int(ray.get_gpu_ids()[0])

     def benchmark(
         self,
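The new `self.device_id` records which physical GPU Ray reserved for this worker, so later tensor allocations and kernel launches can be pinned to it explicitly. A minimal sketch of that mechanism, assuming Ray is installed and at least one GPU is visible (the `GpuWorker` actor below is illustrative, not part of the script):

```python
import ray

@ray.remote(num_gpus=1)
class GpuWorker:
    def __init__(self) -> None:
        # ray.get_gpu_ids() lists the GPUs Ray assigned to this actor;
        # with num_gpus=1 there is exactly one entry.
        self.device_id = int(ray.get_gpu_ids()[0])

    def which_gpu(self) -> int:
        return self.device_id

if __name__ == "__main__":
    ray.init()
    num_gpus = int(ray.cluster_resources().get("GPU", 0))
    workers = [GpuWorker.remote() for _ in range(num_gpus)]
    # Each worker reports a distinct device ID, e.g. [0, 1, ...] on a multi-GPU node.
    print(ray.get([w.which_gpu.remote() for w in workers]))
```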
@@ -283,6 +287,7 @@
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
             kernel_time = benchmark_config(
                 config,
                 num_tokens,
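Wrapping `benchmark_config` in `torch.cuda.device(self.device_id)` makes every `device="cuda"` allocation and kernel launch inside the block target the worker's own GPU rather than device 0. A standalone sketch of that effect (illustrative names; not the benchmark code itself):

```python
import torch
from contextlib import nullcontext

def allocate_on(device_id: int, pin: bool) -> torch.Tensor:
    # Inside torch.cuda.device(device_id), a plain device="cuda" resolves to
    # that GPU; with pin=False the current default device is used instead.
    with torch.cuda.device(device_id) if pin else nullcontext():
        return torch.empty(1024, device="cuda")

if torch.cuda.is_available():
    print(allocate_on(torch.cuda.device_count() - 1, pin=True).device)
```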
@@ -314,6 +319,7 @@
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
             for config in tqdm(search_space):
                 try:
                     kernel_time = benchmark_config(
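In `tune`, the same context wraps the whole search loop, so every candidate configuration is timed on the worker's assigned GPU. A condensed, hypothetical rendering of that loop, with `measure` standing in for `benchmark_config` and its many arguments:

```python
import torch
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

def pick_best(
    search_space: List[Dict[str, Any]],
    measure: Callable[[Dict[str, Any]], float],  # stand-in for benchmark_config
    device_id: int,
    pin_device: bool,
) -> Dict[str, Any]:
    best_config: Optional[Dict[str, Any]] = None
    best_time = float("inf")
    with torch.cuda.device(device_id) if pin_device else nullcontext():
        for config in search_space:
            try:
                kernel_time = measure(config)
            except Exception:
                # Some tile configurations may fail to compile or run; skip them.
                continue
            if kernel_time < best_time:
                best_time, best_config = kernel_time, config
    assert best_config is not None, "no candidate configuration succeeded"
    return best_config
```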