Unverified Commit a3b810eb authored by mpashkovskiy, committed by GitHub

fix: enable multi-GPU Triton fused MoE tuning (#6295)

parent 94959237
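The tuner launches one Ray BenchmarkWorker per GPU, but the worker previously did not pin its work to the GPU Ray assigned it, so on ROCm tensors and compiled Triton kernels could all land on the default device. The change records the assigned device via ray.get_gpu_ids() and, on ROCm only, wraps the benchmark and tune loops in torch.cuda.device(device_id). Below is a minimal, self-contained sketch of the same pattern outside sglang; the TunerWorker class, the matmul timing loop, the candidate sizes, and the torch.version.hip check (standing in for sglang's is_rocm()) are illustrative assumptions, not the project's actual tuning code.

# Hypothetical sketch of the multi-GPU tuning pattern this commit applies:
# one Ray actor per GPU, each pinned to the device Ray assigned it.
from contextlib import nullcontext

import ray
import torch


@ray.remote(num_gpus=1)
class TunerWorker:  # illustrative stand-in for sglang's BenchmarkWorker
    def __init__(self) -> None:
        torch.set_default_device("cuda")
        # Ray reports which physical GPU this actor owns; the commit stores it
        # the same way via int(ray.get_gpu_ids()[0]). This assumes the worker
        # process can still see all GPUs, as in the tuning script's ROCm setup.
        self.device_id = int(ray.get_gpu_ids()[0])

    def tune(self, sizes):
        best_size, best_ms = None, float("inf")
        # On ROCm, pin allocations and kernel launches to this worker's GPU;
        # elsewhere nullcontext() leaves the default device handling untouched.
        ctx = torch.cuda.device(self.device_id) if torch.version.hip else nullcontext()
        with ctx:
            for n in sizes:
                a = torch.randn(n, n, dtype=torch.float16)
                b = torch.randn(n, n, dtype=torch.float16)
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                torch.matmul(a, b)
                end.record()
                torch.cuda.synchronize()
                ms = start.elapsed_time(end)
                if ms < best_ms:
                    best_size, best_ms = n, ms
        return best_size, best_ms


if __name__ == "__main__":
    ray.init()
    num_gpus = int(ray.available_resources().get("GPU", 0))
    workers = [TunerWorker.remote() for _ in range(num_gpus)]
    # Each worker times its candidates on its own GPU, in parallel.
    print(ray.get([w.tune.remote([256, 512, 1024]) for w in workers]))

The nullcontext() branch keeps the non-ROCm path identical to the previous behaviour; only ROCm runs pick up the explicit device pin.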
@@ -3,6 +3,7 @@ import argparse
 import json
 import time
 from datetime import datetime
+from contextlib import nullcontext
 from typing import Any, Dict, List, Tuple, TypedDict
 
 import ray
@@ -21,7 +22,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
 )
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, is_rocm
 
 _is_hip = is_hip()
 
@@ -245,6 +246,9 @@ class BenchmarkWorker:
         torch.set_default_device("cuda")
         torch.cuda.manual_seed_all(0)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU.
+        self.device_id = int(ray.get_gpu_ids()[0])
 
     def benchmark(
         self,
@@ -283,19 +287,20 @@ class BenchmarkWorker:
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
-        kernel_time = benchmark_config(
-            config,
-            num_tokens,
-            num_experts,
-            shard_intermediate_size,
-            hidden_size,
-            topk,
-            dtype,
-            use_fp8_w8a8,
-            use_int8_w8a8,
-            use_int8_w8a16,
-            block_shape,
-        )
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            kernel_time = benchmark_config(
+                config,
+                num_tokens,
+                num_experts,
+                shard_intermediate_size,
+                hidden_size,
+                topk,
+                dtype,
+                use_fp8_w8a8,
+                use_int8_w8a8,
+                use_int8_w8a16,
+                block_shape,
+            )
         return config, kernel_time
 
     def tune(
@@ -314,29 +319,30 @@ class BenchmarkWorker:
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(
-                    config,
-                    num_tokens,
-                    num_experts,
-                    shard_intermediate_size,
-                    hidden_size,
-                    topk,
-                    dtype,
-                    use_fp8_w8a8,
-                    use_int8_w8a8,
-                    use_int8_w8a16,
-                    block_shape,
-                    num_iters=10,
-                )
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
-
-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+        with torch.cuda.device(self.device_id) if is_rocm() else nullcontext():
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a8,
+                        use_int8_w8a16,
+                        block_shape,
+                        num_iters=10,
+                    )
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
+
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config
         now = datetime.now()
         print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None