"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e1a5cc338ba9fba27b0ca1fb54c9951c5146a86f"
Unverified commit 3520e37e, authored by Zach Mueller and committed by GitHub

Enable split_batches through TrainingArguments (#26798)

* Enable split_batches through TrainingArguments

* Extra dispatch_batches

* Keep as default false

* Add to docstring

* Add to docstring

* Remove the capturewarnings change

* Comma
parent 95020f20
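
In practice, the new flag is set directly on TrainingArguments and forwarded to accelerate's Accelerator by the Trainer. A minimal usage sketch (the output directory and batch size are illustrative values, not from this commit):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",               # illustrative path
        per_device_train_batch_size=8,  # illustrative value
        split_batches=True,             # new flag added by this change
    )
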
@@ -3906,6 +3906,7 @@ class Trainer:
         # create accelerator object
         self.accelerator = Accelerator(
             dispatch_batches=self.args.dispatch_batches,
+            split_batches=self.args.split_batches,
             deepspeed_plugin=self.args.deepspeed_plugin,
             gradient_accumulation_plugin=gradient_accumulation_plugin,
         )
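
For reference, TrainingArguments.split_batches maps onto the Accelerator argument of the same name in accelerate. A minimal sketch of the equivalent direct construction, assuming an accelerate version contemporary with this change (where split_batches is still a direct Accelerator keyword):

    from accelerate import Accelerator

    # Equivalent of what the Trainer now builds internally: with
    # split_batches=True, each dataloader batch is sharded across processes
    # instead of every process drawing its own full batch.
    accelerator = Accelerator(split_batches=True)
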
@@ -621,6 +621,14 @@ class TrainingArguments:
             Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
             This flag is experimental and subject to change in future releases.
+        split_batches (`bool`, *optional*):
+            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices
+            during distributed training. If set to `True`, the actual batch size used will be the same on any kind
+            of distributed processes, but it must be a round multiple of the number of processes you are using
+            (such as GPUs).
         include_tokens_per_second (`bool`, *optional*):
             Whether or not to compute the number of tokens per second per device for training speed metrics.
@@ -1226,6 +1234,15 @@ class TrainingArguments:
         },
     )
+    split_batches: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices during distributed training. If"
+            " set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a"
+            " round multiple of the number of processes you are using (such as GPUs)."
+        },
+    )
     include_tokens_per_second: Optional[bool] = field(
         default=False,
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
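
To make the docstring's semantics concrete, a small worked example with illustrative numbers (not part of this commit):

    num_processes = 4        # e.g. 4 GPUs
    dataloader_batch = 32    # batch size each dataloader yields; must be a
                             # round multiple of num_processes (32 % 4 == 0)

    # split_batches=True: one dataloader batch is split across the processes.
    per_process = dataloader_batch // num_processes       # 8 samples per process
    global_batch_split = dataloader_batch                 # global batch stays 32

    # split_batches=False (the default): every process draws its own full batch.
    global_batch_default = dataloader_batch * num_processes  # global batch is 128

    print(per_process, global_batch_split, global_batch_default)
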