Unverified Commit f6f567d0 authored by Sylvain Gugger, committed by GitHub

Fix set of model parallel in the Trainer when no GPUs are available (#25239)

parent d27e4c18
```diff
@@ -378,14 +378,17 @@ class Trainer:
             devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
             if len(devices) > 1:
                 self.is_model_parallel = True
-            else:
+            elif len(devices) == 1:
                 self.is_model_parallel = self.args.device != torch.device(devices[0])
+            else:
+                self.is_model_parallel = False
 
             # warn users
-            logger.info(
-                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
-                " to `True` to avoid any unexpected behavior such as device placement mismatching."
-            )
+            if self.is_model_parallel:
+                logger.info(
+                    "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+                    " to `True` to avoid any unexpected behavior such as device placement mismatching."
+                )
 
             # At this stage the model is already loaded
             if getattr(model, "is_quantized", False):
```
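To see why the extra branch matters: when no GPUs are available, `hf_device_map` contains only `"cpu"`/`"disk"` entries, so the filtered `devices` list is empty and the old `else` branch would evaluate `devices[0]` and raise an `IndexError`. The patch handles the empty case explicitly and only emits the warning when model parallelism is actually set. Below is a minimal standalone sketch of the patched logic; the helper name `is_model_parallel` and the example device maps are illustrative stand-ins, not part of the Trainer API:

```python
import torch

def is_model_parallel(hf_device_map: dict, args_device: torch.device) -> bool:
    """Sketch of the patched Trainer check: True only when the model is
    actually spread across more than one accelerator device."""
    devices = [d for d in set(hf_device_map.values()) if d not in ["cpu", "disk"]]
    if len(devices) > 1:
        # Weights live on several accelerators: genuine model parallelism.
        return True
    elif len(devices) == 1:
        # Single accelerator: parallel only if it differs from the training device.
        return args_device != torch.device(devices[0])
    else:
        # CPU/disk-only map (no GPUs available): before the fix this fell into
        # the `else` above and crashed on `devices[0]`; now it is simply False.
        return False

# Hypothetical device maps for illustration:
print(is_model_parallel({"embed": 0, "lm_head": 1}, torch.device("cuda:0")))   # True
print(is_model_parallel({"embed": "cpu", "lm_head": "disk"}, torch.device("cpu")))  # False (the fixed case)
```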