Unverified commit a3b810eb, authored by mpashkovskiy, committed by GitHub

fix: enable multi-GPU Triton fused MoE tuning (#6295)

parent 94959237
@@ -3,6 +3,7 @@ import argparse
 import json
 import time
 from datetime import datetime
+from contextlib import nullcontext
 from typing import Any, Dict, List, Tuple, TypedDict

 import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, is_rocm

 _is_hip = is_hip()
@@ -245,6 +246,9 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         torch.cuda.manual_seed_all(0)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU.
+        self.device_id = int(ray.get_gpu_ids()[0])

     def benchmark(
         self,
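As an aside, here is a minimal standalone sketch (not part of this patch) of the pattern the hunk above adds to `BenchmarkWorker.__init__`: a Ray actor that reserves one GPU and reads its assigned device ID via `ray.get_gpu_ids()`. The `Worker` class and its `device()` method are hypothetical names used only for illustration.

```python
import ray
import torch


@ray.remote(num_gpus=1)
class Worker:  # hypothetical illustration, not BenchmarkWorker itself
    def __init__(self) -> None:
        # Ray reserves one GPU for this actor; the first (and only) entry
        # returned by get_gpu_ids() is the device this worker should target.
        self.device_id = int(ray.get_gpu_ids()[0])

    def device(self) -> int:
        return self.device_id


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(torch.cuda.device_count())]
    # Each worker reports its assigned GPU ID, e.g. [0, 1, 2, 3] on a 4-GPU node.
    print(ray.get([w.device.remote() for w in workers]))
```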
@@ -283,19 +287,20 @@
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
-        kernel_time = benchmark_config(
-            config,
-            num_tokens,
-            num_experts,
-            shard_intermediate_size,
-            hidden_size,
-            topk,
-            dtype,
-            use_fp8_w8a8,
-            use_int8_w8a8,
-            use_int8_w8a16,
-            block_shape,
-        )
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            kernel_time = benchmark_config(
+                config,
+                num_tokens,
+                num_experts,
+                shard_intermediate_size,
+                hidden_size,
+                topk,
+                dtype,
+                use_fp8_w8a8,
+                use_int8_w8a8,
+                use_int8_w8a16,
+                block_shape,
+            )
         return config, kernel_time

     def tune(
@@ -314,29 +319,30 @@
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(
-                    config,
-                    num_tokens,
-                    num_experts,
-                    shard_intermediate_size,
-                    hidden_size,
-                    topk,
-                    dtype,
-                    use_fp8_w8a8,
-                    use_int8_w8a8,
-                    use_int8_w8a16,
-                    block_shape,
-                    num_iters=10,
-                )
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a8,
+                        use_int8_w8a16,
+                        block_shape,
+                        num_iters=10,
+                    )
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config

         now = datetime.now()
         print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None
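For reference, a minimal sketch of the conditional device-context idiom used in both hunks above: the block is pinned to a specific GPU only when a flag is set, and otherwise runs under a no-op `nullcontext()`. It assumes a CUDA or ROCm build of PyTorch with at least one visible GPU; `run_on` and `pin_device` are hypothetical names, not part of the benchmark script.

```python
from contextlib import nullcontext

import torch


def run_on(device_id: int, pin_device: bool) -> torch.Tensor:
    # torch.cuda.device(device_id) makes that GPU the current device inside
    # the block; nullcontext() leaves the ambient current device untouched.
    with torch.cuda.device(device_id) if pin_device else nullcontext():
        x = torch.randn(512, 512, device="cuda")  # lands on the current device
        y = x @ x
        torch.cuda.synchronize()
    return y


if __name__ == "__main__":
    out = run_on(device_id=torch.cuda.device_count() - 1, pin_device=True)
    print(out.device)  # e.g. cuda:1 on a 2-GPU machine
```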