Unverified Commit f6f567d0 authored by Sylvain Gugger, committed by GitHub

Fix set of model parallel in the Trainer when no GPUs are available (#25239)

parent d27e4c18
@@ -378,14 +378,17 @@ class Trainer:
             devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
             if len(devices) > 1:
                 self.is_model_parallel = True
-            else:
+            elif len(devices) == 1:
                 self.is_model_parallel = self.args.device != torch.device(devices[0])
+            else:
+                self.is_model_parallel = False
 
             # warn users
-            logger.info(
-                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
-                " to `True` to avoid any unexpected behavior such as device placement mismatching."
-            )
+            if self.is_model_parallel:
+                logger.info(
+                    "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+                    " to `True` to avoid any unexpected behavior such as device placement mismatching."
+                )
 
             # At this stage the model is already loaded
             if getattr(model, "is_quantized", False):
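For context, here is a minimal sketch of the failure this commit fixes (the `hf_device_map` contents and standalone variables below are illustrative, not the Trainer's real state): when a model is dispatched entirely to CPU or disk, `devices` comes back empty, so the old `else` branch indexed `devices[0]` and raised an `IndexError`. The new `elif`/`else` split handles the empty case, and the info message is now logged only when model parallelism is actually detected.

```python
# Hypothetical device map as produced when no GPU is available:
# every module lands on "cpu" or "disk".
hf_device_map = {"model.embed_tokens": "cpu", "model.layers.0": "disk"}

devices = [d for d in set(hf_device_map.values()) if d not in ["cpu", "disk"]]
print(devices)  # [] -- no accelerator devices in a CPU/disk-only setup

# Old logic: the `else` branch called torch.device(devices[0]),
# which raises IndexError on an empty list.
# New logic, mirroring the committed code:
if len(devices) > 1:
    is_model_parallel = True
elif len(devices) == 1:
    # The real code compares self.args.device against torch.device(devices[0]).
    is_model_parallel = True  # placeholder for that comparison
else:
    is_model_parallel = False  # new fall-through: no GPUs, no model parallelism

print(is_model_parallel)  # False
```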