"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cad1b1192b3ad98808d24f898e28fb56f78720d3"
Unverified Commit ee88ae59 authored by Teven, committed by GitHub

Adding ddp_broadcast_buffers argument to Trainer (#24326)

adding ddp_broadcast_buffers argument
parent 91389950
@@ -1450,6 +1450,9 @@ class Trainer:
             if self.args.ddp_bucket_cap_mb is not None:
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+            if self.args.ddp_broadcast_buffers is not None:
+                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
             self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
         return model
......
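For context, here is a minimal sketch of what such a `DistributedDataParallelKwargs` handler does on the accelerate side. The toy model and the explicit `kwargs_handlers` wiring below are illustrative only (the Trainer sets the handler internally, as in the hunk above); they are not part of this diff.

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# broadcast_buffers=False keeps each rank's buffers (e.g. BatchNorm running
# stats) local instead of re-broadcasting rank 0's buffers on every forward.
ddp_kwargs = DistributedDataParallelKwargs(broadcast_buffers=False, bucket_cap_mb=25)

accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
model = torch.nn.Linear(16, 2)      # toy stand-in for the real model
model = accelerator.prepare(model)  # wrapped in DistributedDataParallel with the
                                    # kwargs above when launched with >1 process
```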
@@ -505,6 +505,9 @@ class TrainingArguments:
             `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
         ddp_bucket_cap_mb (`int`, *optional*):
             When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
+        ddp_broadcast_buffers (`bool`, *optional*):
+            When using distributed training, the value of the flag `broadcast_buffers` passed to
+            `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
         dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
             Whether you want to pin memory in data loaders or not. Will default to `True`.
         skip_memory_metrics (`bool`, *optional*, defaults to `True`):
@@ -1045,6 +1048,15 @@ class TrainingArguments:
             )
         },
     )
+    ddp_broadcast_buffers: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": (
+                "When using distributed training, the value of the flag `broadcast_buffers` passed to "
+                "`DistributedDataParallel`."
+            )
+        },
+    )
     dataloader_pin_memory: bool = field(
         default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
     )
......
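From the user's perspective, the new argument is just another `TrainingArguments` field. A hedged usage sketch follows; the output dir, batch size, and the commented-out model/dataset names are placeholders, not part of this PR.

```python
from transformers import TrainingArguments

# Placeholder configuration: only ddp_broadcast_buffers comes from this change.
args = TrainingArguments(
    output_dir="out",               # placeholder path
    per_device_train_batch_size=8,
    ddp_broadcast_buffers=False,    # forwarded to DistributedDataParallel(broadcast_buffers=...)
)

# trainer = Trainer(model=model, args=args, train_dataset=train_ds)  # usual setup
# Run distributed, e.g.: torchrun --nproc_per_node=2 train.py
```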