"docs/source/vscode:/vscode.git/clone" did not exist on "a0e69a9375c808296e69332728021c4f4f35c327"
Unverified Commit 68d53bc7 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix Trainer when model is loaded on a different GPU (#23792)

parent 0963a250
......@@ -405,12 +405,12 @@ class Trainer:
else:
self.is_model_parallel = False
if (
getattr(model, "hf_device_map", None) is not None
and len([device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]) > 1
and not self.is_model_parallel
):
self.is_model_parallel = True
if getattr(model, "hf_device_map", None) is not None:
devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
if len(devices) > 1:
self.is_model_parallel = True
else:
self.is_model_parallel = self.args.device != torch.device(devices[0])
# warn users
logger.info(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment