Unverified Commit c7b7bd99 authored by Sylvain Gugger, committed by GitHub

Add a flag for find_unused_parameters (#9820)



* Add a flag for find_unused_parameters

* Apply suggestions from code review
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Remove negation
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 4adbdce5
src/transformers/trainer.py
@@ -761,18 +761,20 @@ class Trainer:
         elif is_sagemaker_distributed_available():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
+            if self.args.ddp_find_unused_parameters is not None:
+                find_unused_parameters = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
+            else:
+                find_unused_parameters = True
             model = torch.nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=(
-                    not getattr(model.config, "gradient_checkpointing", False)
-                    if isinstance(model, PreTrainedModel)
-                    else True
-                ),
+                find_unused_parameters=find_unused_parameters,
             )
-            # find_unused_parameters breaks checkpointing as per
-            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
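The new branch above reduces to a three-way choice. As a rough paraphrase (the helper below is invented for illustration and is not part of this commit; `args` stands for a `TrainingArguments` instance): an explicit `ddp_find_unused_parameters` value wins, otherwise a `PreTrainedModel` with gradient checkpointing enabled turns the flag off, and anything else falls back to `True`.

```python
from transformers import PreTrainedModel


def resolve_find_unused_parameters(args, model):
    """Sketch only: mirrors the resolution order in the Trainer hunk above."""
    if args.ddp_find_unused_parameters is not None:
        # An explicit TrainingArguments value always takes precedence.
        return args.ddp_find_unused_parameters
    if isinstance(model, PreTrainedModel):
        # find_unused_parameters breaks gradient checkpointing, see
        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
        return not getattr(model.config, "gradient_checkpointing", False)
    # Models that are not PreTrainedModel keep the previous default of True.
    return True
```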
src/transformers/training_args.py
@@ -240,6 +240,10 @@ class TrainingArguments:
         report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed):
             The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
             :obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
+        ddp_find_unused_parameters (:obj:`bool`, `optional`):
+            When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
+            :obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True`
+            otherwise.
     """

     output_dir: str = field(
@@ -425,6 +429,13 @@ class TrainingArguments:
     report_to: Optional[List[str]] = field(
         default=None, metadata={"help": "The list of integrations to report the results and logs to."}
     )
+    ddp_find_unused_parameters: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+            "`DistributedDataParallel`."
+        },
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)

     def __post_init__(self):
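A minimal usage sketch of the new argument (the output directory below is a placeholder, not taken from this commit): leaving it at `None` keeps the automatic default described in the docstring, while an explicit value is forwarded unchanged to `torch.nn.parallel.DistributedDataParallel` when training runs with `local_rank != -1`.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",               # placeholder directory for this sketch
    ddp_find_unused_parameters=False,  # force-disable unused-parameter tracking in DDP
)
print(args.ddp_find_unused_parameters)  # False
```

In a single-process run the `local_rank != -1` branch is never taken, so the flag has no effect there.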