Unverified commit f78c0be8, authored by Michael Goin, committed by GitHub
Browse files

Fix benchmark_moe.py tuning for CUDA devices (#14164)

parent 66233af7
......@@ -2,6 +2,7 @@
import argparse
import time
+from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict
......@@ -412,7 +413,8 @@ class BenchmarkWorker:
hidden_size, search_space,
is_fp16, topk)
-with torch.cuda.device(self.device_id):
+with torch.cuda.device(self.device_id) if current_platform.is_rocm(
+) else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment