Unverified commit ee88ae59, authored by Teven, committed by GitHub

Adding ddp_broadcast_buffers argument to Trainer (#24326)

adding ddp_broadcast_buffers argument
parent 91389950
@@ -1450,6 +1450,9 @@ class Trainer:
            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            if self.args.ddp_broadcast_buffers is not None:
                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
        return model
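For context, a rough sketch of the mechanism this hunk relies on: `DistributedDataParallelKwargs` is accelerate's kwargs handler, and fields set on it (such as `broadcast_buffers` and `bucket_cap_mb`) are forwarded to `torch.nn.parallel.DistributedDataParallel` when the model is prepared. The values below are illustrative only, not taken from this commit.

```python
# Minimal sketch (outside the Trainer) of how the kwargs assembled above reach DDP.
# Illustrative values only; the Trainer fills these in from TrainingArguments.
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(
    bucket_cap_mb=25,         # mirrors args.ddp_bucket_cap_mb
    broadcast_buffers=False,  # mirrors the new args.ddp_broadcast_buffers
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# model = accelerator.prepare(model)  # wraps the model in DDP with these kwargs
```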
@@ -505,6 +505,9 @@ class TrainingArguments:
            `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
        ddp_bucket_cap_mb (`int`, *optional*):
            When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
        ddp_broadcast_buffers (`bool`, *optional*):
            When using distributed training, the value of the flag `broadcast_buffers` passed to
            `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
        dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
            Whether you want to pin memory in data loaders or not. Will default to `True`.
        skip_memory_metrics (`bool`, *optional*, defaults to `True`):
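A minimal, illustrative example of setting the new argument from user code (the `output_dir` value is just a placeholder):

```python
from transformers import TrainingArguments

# Explicitly disable buffer broadcasting in DDP; leaving it at None keeps the
# default described above (False with gradient checkpointing, True otherwise).
args = TrainingArguments(
    output_dir="out",
    ddp_broadcast_buffers=False,
)
```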
@@ -1045,6 +1048,15 @@ class TrainingArguments:
            )
        },
    )
    ddp_broadcast_buffers: Optional[bool] = field(
        default=None,
        metadata={
            "help": (
                "When using distributed training, the value of the flag `broadcast_buffers` passed to "
                "`DistributedDataParallel`."
            )
        },
    )
    dataloader_pin_memory: bool = field(
        default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
    )
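Since `TrainingArguments` is usually built through `HfArgumentParser`, the new field should also be settable from the command line; a small sketch with an illustrative argument list:

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--ddp_broadcast_buffers", "False"]
)
print(training_args.ddp_broadcast_buffers)  # False
```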