"tests/vscode:/vscode.git/clone" did not exist on "fd6902838afa35973f4fcc97ec0dcd1de888883e"
Unverified Commit cab048fb authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Trainer`] Force `is_model_parallel` when model is loaded in multiple GPUs...

[`Trainer`] Force `is_model_parallel` when model is loaded in multiple GPUs using `accelerate` (#22532)

* add `is_model_parallel` arg on Trainer

* add warning

* adapt from suggestions

* revert t5 changes

* remove commas

* adapt from suggestions
parent aecbcb36
......@@ -370,6 +370,19 @@ class Trainer:
else:
self.is_model_parallel = False
if (
getattr(model, "hf_device_map", None) is not None
and len([device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]) > 1
and not self.is_model_parallel
):
self.is_model_parallel = True
# warn users
logger.info(
"You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)
# At this stage the model is already loaded
if getattr(model, "is_loaded_in_8bit", False):
if getattr(model, "_is_int8_training_enabled", False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment