Unverified Commit 63a0c8f1 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[tests] enable benchmark unit tests on XPU (#29284)

* add xpu for benchmark

* no auto_map

* use require_torch_gpu

* use gpu

* revert

* revert

* fix style
parent 6d3b643e
...@@ -17,7 +17,14 @@ ...@@ -17,7 +17,14 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends from ..utils import (
cached_property,
is_torch_available,
is_torch_tpu_available,
is_torch_xpu_available,
logging,
requires_backends,
)
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
...@@ -84,6 +91,9 @@ class PyTorchBenchmarkArguments(BenchmarkArguments): ...@@ -84,6 +91,9 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
elif is_torch_tpu_available(): elif is_torch_tpu_available():
device = xm.xla_device() device = xm.xla_device()
n_gpu = 0 n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu")
n_gpu = torch.xpu.device_count()
else: else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count() n_gpu = torch.cuda.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment