Unverified Commit 3b734f50 authored by Zach Mueller, committed by GitHub

Add dispatch_batches to training arguments (#25038)

* Dispatch batches

* Copy items
parent 9d2b983e
src/transformers/trainer.py

@@ -3806,7 +3806,9 @@ class Trainer:
         # create accelerator object
         self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin
+            dispatch_batches=self.args.dispatch_batches,
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_plugin=gradient_accumulation_plugin,
         )
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
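The new argument is forwarded straight to Accelerate. A minimal standalone sketch of what it controls at that level (not Trainer internals; the dataset here is a stand-in):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# With dispatch_batches=True, only the main process iterates the dataloader and
# each batch is split and broadcast to the other processes; with False, every
# process iterates its own dataloader.
accelerator = Accelerator(dispatch_batches=False)

dataset = TensorDataset(torch.randn(64, 4), torch.randn(64))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8))

for batch in dataloader:
    pass  # batches now arrive according to the chosen dispatch mode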
src/transformers/training_args.py

@@ -1200,6 +1200,15 @@ class TrainingArguments:
         },
     )
+    dispatch_batches: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to dispatch batches across devices in distributed training. If set to `True`, the"
+            " dataloader prepared by the Accelerator is only iterated through on the main process and then the"
+            " batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
+            " underlying dataset is an `IterableDataset`, `False` otherwise."
+        },
+    )
 
     def __post_init__(self):
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
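With the field in place, the flag can be set like any other training argument. A short sketch (output_dir is a placeholder):

from transformers import TrainingArguments

# None (the default) keeps Accelerate's heuristic: dispatching is enabled for
# dataloaders backed by an IterableDataset and disabled otherwise.
args = TrainingArguments(output_dir="out", dispatch_batches=False)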
tests/trainer/test_trainer_distributed.py

@@ -14,6 +14,8 @@
 from typing import Dict
 
+import numpy as np
+
 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
@@ -33,7 +35,7 @@ logger = logging.get_logger(__name__)
 if is_torch_available():
     import torch
     from torch import nn
-    from torch.utils.data import Dataset
+    from torch.utils.data import Dataset, IterableDataset
 
     from transformers import Trainer
@@ -63,6 +65,56 @@ if is_torch_available():
         else:
             return input_ids
 
+    class RegressionModel(nn.Module):
+        def __init__(self, a=0, b=0, double_output=False):
+            super().__init__()
+            self.a = nn.Parameter(torch.tensor(a).float())
+            self.b = nn.Parameter(torch.tensor(b).float())
+            self.double_output = double_output
+            self.config = None
+
+        def forward(self, input_x, labels=None, **kwargs):
+            y = input_x * self.a + self.b
+            if labels is None:
+                return (y, y) if self.double_output else (y,)
+            loss = nn.functional.mse_loss(y, labels)
+            return (loss, y, y) if self.double_output else (loss, y)
+
+    class SampleIterableDataset(IterableDataset):
+        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
+            self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)
+
+        def __iter__(self):
+            for i in range(len(self.dataset)):
+                yield self.dataset[i]
+
+    class FiniteIterableDataset(SampleIterableDataset):
+        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
+            super().__init__(a, b, length, seed, label_names)
+            self.current_sample = 0
+
+        def __iter__(self):
+            while self.current_sample < len(self.dataset):
+                yield self.dataset[self.current_sample]
+                self.current_sample += 1
+
+    class RegressionDataset:
+        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
+            np.random.seed(seed)
+            self.label_names = ["labels"] if label_names is None else label_names
+            self.length = length
+            self.x = np.random.normal(size=(length,)).astype(np.float32)
+            self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names]
+            self.ys = [y.astype(np.float32) for y in self.ys]
+
+        def __len__(self):
+            return self.length
+
+        def __getitem__(self, i):
+            result = {name: y[i] for name, y in zip(self.label_names, self.ys)}
+            result["input_x"] = self.x[i]
+            return result
+
 
 class TestTrainerDistributedNeuronCore(TestCasePlus):
     @require_torch_neuroncore
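One subtlety in these fixtures: unlike SampleIterableDataset, FiniteIterableDataset keeps its cursor across iterations, so it is exhausted after a single pass. A quick sanity-check sketch, assuming the classes above are in scope:

ds = FiniteIterableDataset(length=4)
print(len(list(ds)))  # 4 -- yields each sample once
print(len(list(ds)))  # 0 -- current_sample is never reset, so the dataset is spent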
@@ -168,3 +220,14 @@ if __name__ == "__main__":
             exit(1)
 
     trainer.args.eval_accumulation_steps = None
+
+    # Check that `dispatch_batches=False` will work on a finite iterable dataset
+    train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
+
+    model = RegressionModel()
+    training_args.per_device_train_batch_size = 1
+    training_args.max_steps = 1
+    training_args.dispatch_batches = False
+    trainer = Trainer(model, training_args, train_dataset=train_dataset)
+    trainer.train()
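Because this check sits under the script's if __name__ == "__main__": block, it only runs when the file is launched directly under a distributed launcher; something like the following (the exact path, process count, and output directory are assumptions):

python -m torch.distributed.run --nproc_per_node 2 tests/trainer/test_trainer_distributed.py --output_dir /tmp/trainer_distributed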