"vscode:/vscode.git/clone" did not exist on "3b419cfc6fdcdd09ef02ae05772295c836ce3cd5"
Unverified commit 68d53bc7, authored by Sylvain Gugger and committed by GitHub
Browse files

Fix Trainer when model is loaded on a different GPU (#23792)

parent 0963a250
@@ -405,12 +405,12 @@ class Trainer:
         else:
             self.is_model_parallel = False
-        if (
-            getattr(model, "hf_device_map", None) is not None
-            and len([device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]) > 1
-            and not self.is_model_parallel
-        ):
-            self.is_model_parallel = True
+        if getattr(model, "hf_device_map", None) is not None:
+            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
+            if len(devices) > 1:
+                self.is_model_parallel = True
+            else:
+                self.is_model_parallel = self.args.device != torch.device(devices[0])
             # warn users
             logger.info(
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment