"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "f0cb1d123f22183b0a52c2b4da4a55fa97f170ac"
Unverified Commit b25f7802 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Should check that torch TPU is available (#5636)

parent 3cc23eee
...@@ -34,6 +34,7 @@ from .file_utils import ( ...@@ -34,6 +34,7 @@ from .file_utils import (
cached_path, cached_path,
hf_bucket_url, hf_bucket_url,
is_remote_url, is_remote_url,
is_torch_tpu_available,
) )
from .generation_utils import GenerationMixin from .generation_utils import GenerationMixin
...@@ -794,7 +795,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -794,7 +795,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
} }
return model, loading_info return model, loading_info
if hasattr(config, "xla_device") and config.xla_device: if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device()) model = xm.send_cpu_data_to_device(model, xm.xla_device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment