Unverified commit f78c0be8, authored by Michael Goin and committed by GitHub
Browse files

Fix benchmark_moe.py tuning for CUDA devices (#14164)

parent 66233af7
@@ -2,6 +2,7 @@
 import argparse
 import time
+from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
 from typing import Any, TypedDict
@@ -412,7 +413,8 @@ class BenchmarkWorker:
 hidden_size, search_space,
 is_fp16, topk)
-with torch.cuda.device(self.device_id):
+with torch.cuda.device(self.device_id) if current_platform.is_rocm(
+) else nullcontext():
 for config in tqdm(search_space):
 try:
 kernel_time = benchmark_config(
...
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment