Unverified Commit a3b810eb authored by mpashkovskiy, committed by GitHub

fix: enable multi-GPU Triton fused MoE tuning (#6295)

parent 94959237
@@ -3,6 +3,7 @@ import argparse
 import json
 import time
 from datetime import datetime
+from contextlib import nullcontext
 from typing import Any, Dict, List, Tuple, TypedDict

 import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, is_rocm

 _is_hip = is_hip()
@@ -245,6 +246,9 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         torch.cuda.manual_seed_all(0)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU.
+        self.device_id = int(ray.get_gpu_ids()[0])

     def benchmark(
         self,
@@ -283,6 +287,7 @@ class BenchmarkWorker:
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
            kernel_time = benchmark_config(
                config,
                num_tokens,
@@ -314,6 +319,7 @@ class BenchmarkWorker:
     ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
...
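
For context, the pattern this diff adds is: each Ray benchmark worker records the GPU that Ray assigned to it via ray.get_gpu_ids(), and on ROCm it wraps the benchmarking calls in torch.cuda.device(device_id) so tensor allocation and Triton kernel launches happen on that worker's GPU instead of all defaulting to device 0; on CUDA, nullcontext() leaves the existing behavior unchanged. The sketch below is a minimal, self-contained illustration of that pattern under the assumption that ray, torch, and at least one GPU are available. It is not the actual tuning script: DummyWorker, run_benchmark(), and the inline ROCm check are hypothetical stand-ins for sglang's BenchmarkWorker, benchmark_config(), and is_rocm().

# Minimal sketch of the multi-GPU device-pinning pattern used in the diff above.
# DummyWorker, run_benchmark(), and the inline ROCm check are illustrative
# stand-ins, not part of sglang.
from contextlib import nullcontext

import ray
import torch


@ray.remote(num_gpus=1)  # Ray assigns one GPU per actor; get_gpu_ids() reports it.
class DummyWorker:
    def __init__(self) -> None:
        torch.set_default_device("cuda")
        # The GPU Ray assigned to this actor, e.g. [3] -> device 3.
        self.device_id = int(ray.get_gpu_ids()[0])

    def run_benchmark(self) -> float:
        # On ROCm, select this worker's GPU explicitly so allocations and kernel
        # launches do not all land on device 0; on CUDA, nullcontext() is a no-op.
        is_rocm = torch.version.hip is not None  # simplified stand-in for is_rocm()
        with torch.cuda.device(self.device_id) if is_rocm else nullcontext():
            x = torch.randn(1024, 1024)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            torch.matmul(x, x)  # placeholder for the kernel being benchmarked
            end.record()
            torch.cuda.synchronize()
            return start.elapsed_time(end)  # milliseconds


if __name__ == "__main__":
    ray.init()
    num_gpus = int(ray.cluster_resources().get("GPU", 1))
    workers = [DummyWorker.remote() for _ in range(num_gpus)]
    # Each worker benchmarks on its own GPU in parallel.
    print(ray.get([w.run_benchmark.remote() for w in workers]))

The @ray.remote(num_gpus=1) decorator is what makes ray.get_gpu_ids() return exactly one ID inside each actor, so every worker can pin itself to its own GPU and the tuning runs in parallel across devices.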