Unverified Commit e4910213 authored by Zachary Mueller, committed by GitHub

Warn on TPUs when the custom optimizer and model device are not the same (#18668)

* Check optimizer for device on TPU

* Typo
parent cdde85a0
@@ -465,6 +465,21 @@ class Trainer:
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting it on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` are performed before the optimizer creation in your script."
                )
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
...
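For reference, a minimal sketch of the ordering this check enforces, assuming `torch_xla` is installed and an XLA device is available. The `torch.nn.Linear` model and AdamW optimizer here are hypothetical stand-ins, not part of this commit:

```python
# Sketch (not from this commit): create the optimizer only *after* moving the
# model to the XLA device, so its param_groups hold on-device parameters.
import torch
import torch_xla.core.xla_model as xm

model = torch.nn.Linear(10, 2)      # stand-in for any model passed to `Trainer`
model.to(xm.xla_device())           # move to the TPU first
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Model and optimizer parameters now report the same device, so the new
# check in `Trainer.__init__` passes instead of raising a ValueError.
assert next(model.parameters()).device == optimizer.param_groups[0]["params"][0].device
```

Creating the optimizer before the `model.to(...)` call would leave the optimizer's `param_groups` pointing at the CPU copies of the parameters, which is exactly the mismatch the new check detects.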