Unverified commit 1531b319, authored by Chang Lan, committed by GitHub

Add an argument to set bucket_cap_mb for PyTorch DDP (#14756)

* [trainer] Set bucket_cap_mb for DDP from arguments

* Put find_unused_parameters into kwargs
parent 3883e3a7
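
For context, a minimal usage sketch (not part of this diff): the new `ddp_bucket_cap_mb` field is set like any other `TrainingArguments` field, and the `Trainer` forwards it to `DistributedDataParallel` when training with `local_rank != -1`.

```python
# Minimal usage sketch (not part of this PR). "out" is a placeholder output
# directory; the Trainer passes ddp_bucket_cap_mb on to
# torch.nn.parallel.DistributedDataParallel under distributed training.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    ddp_bucket_cap_mb=50,             # DDP gradient-bucket size, in MB
    ddp_find_unused_parameters=False,
)
```

Leaving the flag unset keeps PyTorch's default bucket size of 25 MB.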
```diff
@@ -999,19 +999,23 @@ class Trainer:
         elif is_sagemaker_dp_enabled():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
+            kwargs = {}
             if self.args.ddp_find_unused_parameters is not None:
-                find_unused_parameters = self.args.ddp_find_unused_parameters
+                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
             elif isinstance(model, PreTrainedModel):
                 # find_unused_parameters breaks checkpointing as per
                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not model.is_gradient_checkpointing
+                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
             else:
-                find_unused_parameters = True
+                kwargs["find_unused_parameters"] = True
+            if self.args.ddp_bucket_cap_mb is not None:
+                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
             model = nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                 output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                find_unused_parameters=find_unused_parameters,
+                **kwargs,
             )
         return model
```
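
The `kwargs` dict keeps both options optional, so DDP's own defaults apply whenever a flag is unset. Below is a self-contained sketch of the same forwarding pattern; `wrap_ddp` is a hypothetical helper, not part of the `Trainer` API, and it assumes the process group has already been initialized (e.g. via `torchrun`).

```python
# Sketch of the kwargs-forwarding pattern above. wrap_ddp is a hypothetical
# helper, not Trainer API; torch.distributed must already be initialized.
from typing import Optional

import torch.nn as nn


def wrap_ddp(
    model: nn.Module,
    local_rank: int,
    find_unused_parameters: Optional[bool] = None,
    bucket_cap_mb: Optional[int] = None,
) -> nn.parallel.DistributedDataParallel:
    kwargs = {}
    if find_unused_parameters is not None:
        kwargs["find_unused_parameters"] = find_unused_parameters
    if bucket_cap_mb is not None:
        # Size (in MB) of the gradient buckets DDP allreduces; PyTorch defaults to 25.
        kwargs["bucket_cap_mb"] = bucket_cap_mb
    return nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        **kwargs,  # unset options fall back to DDP's defaults
    )
```

Larger buckets mean fewer (but bigger) allreduce calls at the cost of more gradient memory held per bucket; when the flag is unset, the default is left untouched.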
```diff
@@ -348,6 +348,9 @@ class TrainingArguments:
             When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
             :obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True`
             otherwise.
+        ddp_bucket_cap_mb (:obj:`int`, `optional`):
+            When using distributed training, the value of the flag :obj:`bucket_cap_mb` passed to
+            :obj:`DistributedDataParallel`.
         dataloader_pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether you want to pin memory in data loaders or not. Will default to :obj:`True`.
         skip_memory_metrics (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -665,6 +668,13 @@ class TrainingArguments:
             "`DistributedDataParallel`."
         },
     )
+    ddp_bucket_cap_mb: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
+            "`DistributedDataParallel`."
+        },
+    )
     dataloader_pin_memory: bool = field(
         default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
     )
```
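
Because `ddp_bucket_cap_mb` is a regular dataclass field with help metadata, `HfArgumentParser` also exposes it on the command line. A small sketch (the argument list is illustrative):

```python
# Sketch: the new field is picked up by HfArgumentParser, so example scripts
# can set it with --ddp_bucket_cap_mb; the argument list here is illustrative.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--ddp_bucket_cap_mb", "25"]
)
print(training_args.ddp_bucket_cap_mb)  # 25
```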