Unverified commit d7c8ce57 authored by Sander Land, committed by GitHub

Avoid accessing .dataset of a DataLoader in Trainer (#16451)



* Avoid accessing .dataset of a dataloader

* style

* fix

* cleaning up, reverting some misunderstandings

* black

* add train_dataset argument to get_train_dataloader, and fix other instances of length checks

* flake8

* address comments

* fix bug

* cleanup

* add test

* Update tests/trainer/test_trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* under torch

* merge

* stylistic suggestion
Co-authored-by: Sander Land <sander@chatdesk.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 781af736
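The change guards every place the Trainer previously assumed `dataloader.dataset` exists and is sized, so that custom loader objects that only implement `__iter__` and `__len__` keep working. A minimal illustration of the failure mode and the new fallback, using a hypothetical `ConcatLoader` that is not part of this PR (the test at the bottom of the diff uses a similar `MultiLoader`); the per-device batch size of 8 is an assumption for the example:

import torch
from torch.utils.data import DataLoader, TensorDataset

class ConcatLoader:
    # Hypothetical loader: iterable and sized, but exposes no .dataset attribute.
    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        return sum(len(loader) for loader in self.loaders)

    def __iter__(self):
        for loader in self.loaders:
            yield from loader

loader = ConcatLoader([DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8)] * 2)

# Before this change: Trainer.num_examples did len(loader.dataset) -> AttributeError.
# After this change: the error is caught and the count is estimated from the loader's
# own length times the per-device batch size (assumed to be 8 here).
try:
    num_examples = len(loader.dataset)
except (NameError, AttributeError, TypeError):
    num_examples = len(loader) * 8
assert num_examples == 128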
@@ -585,7 +585,7 @@ class Trainer:
return dataset.remove_columns(ignored_columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if not has_length(self.train_dataset):
if self.train_dataset is None or not has_length(self.train_dataset):
return None
generator = None
@@ -661,8 +661,8 @@ class Trainer:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Will use no sampler if `self.train_dataset` does not implement `__len__`, a random sampler (adapted to
distributed training if necessary) otherwise.
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
@@ -937,11 +937,13 @@ class Trainer:
def num_examples(self, dataloader: DataLoader) -> int:
"""
Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset.
Will raise an exception if the underlying dataset does not implement method `__len__`
Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
dataloader.dataset does not exist or has no length, estimates as best it can
"""
try:
return len(dataloader.dataset)
except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
return len(dataloader) * self.args.per_device_train_batch_size
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
"""HP search setup code"""
@@ -1198,9 +1200,6 @@ class Trainer:
self._move_model_to_device(self.model, args.device)
self.model_wrapped = self.model
# Keeping track whether we can can len() on the dataset or not
train_dataset_is_sized = has_length(self.train_dataset)
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
@@ -1209,28 +1208,36 @@ class Trainer:
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
if train_dataset_is_sized:
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
len_dataloader = None
if has_length(train_dataloader):
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
# May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = len(self.train_dataset) * args.num_train_epochs
else:
# see __init__. max_steps is set when the dataset has no __len__
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
max_steps = args.max_steps
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(
f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
)
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
@@ -1281,10 +1288,6 @@ class Trainer:
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
# Train!
num_examples = (
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
@@ -1370,7 +1373,7 @@ class Trainer:
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
elif isinstance(train_dataloader.dataset, IterableDatasetShard):
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
train_dataloader.dataset.set_epoch(epoch)
if is_torch_tpu_available():
@@ -1384,7 +1387,9 @@ class Trainer:
self._past = None
steps_in_epoch = (
len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
len(epoch_iterator)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
@@ -2407,10 +2412,10 @@ class Trainer:
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = dataloader.batch_size
batch_size = self.args.per_device_eval_batch_size
logger.info(f"***** Running {description} *****")
if has_length(dataloader.dataset):
if has_length(dataloader):
logger.info(f" Num examples = {self.num_examples(dataloader)}")
else:
logger.info(" Num examples: Unknown")
@@ -2420,7 +2425,7 @@ class Trainer:
self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
eval_dataset = dataloader.dataset
eval_dataset = getattr(dataloader, "dataset", None)
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
@@ -2512,6 +2517,9 @@ class Trainer:
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
num_samples = eval_dataset.num_examples
else:
if has_length(dataloader):
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
@@ -2899,8 +2907,9 @@ class Trainer:
"""
args = self.args
if not has_length(dataloader.dataset):
raise ValueError("dataset must implement __len__")
if not has_length(dataloader):
raise ValueError("dataloader must implement a working __len__")
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train init deepspeed here
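To make the rewritten training-setup logic (the @@ -1209 hunk above) easier to follow, here is a condensed paraphrase. It is an illustration, not the exact Trainer code: `plan_training` and `num_examples_fn` are hypothetical names (`num_examples_fn` stands in for `self.num_examples`), and `args` is assumed to carry the usual TrainingArguments fields used in the diff.

import math
import sys

from transformers.trainer_utils import has_length

def plan_training(train_dataloader, args, total_train_batch_size, num_examples_fn):
    # Mirrors the new logic: branch on whether the dataloader itself has a usable length.
    if has_length(train_dataloader):
        len_dataloader = len(train_dataloader)
        num_update_steps_per_epoch = max(len_dataloader // args.gradient_accumulation_steps, 1)
        if args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)
            num_train_samples = num_examples_fn(train_dataloader) * args.num_train_epochs
    elif args.max_steps > 0:
        # No usable length: rely entirely on max_steps and loop over the iterator as needed.
        max_steps = args.max_steps
        num_train_epochs = sys.maxsize
        num_update_steps_per_epoch = max_steps
        num_train_samples = args.max_steps * total_train_batch_size
    else:
        raise ValueError("args.max_steps must be set to a positive value if the dataloader has no length")
    return max_steps, num_train_epochs, num_update_steps_per_epoch, num_train_samples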
@@ -473,7 +473,7 @@ class ProgressCallback(TrainerCallback):
self.current_step = state.global_step
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero and has_length(eval_dataloader.dataset):
if state.is_local_process_zero and has_length(eval_dataloader):
if self.prediction_bar is None:
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
self.prediction_bar.update(1)
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import re
import time
from typing import Optional
@@ -21,7 +20,7 @@ from typing import Optional
import IPython.display as disp
from ..trainer_callback import TrainerCallback
from ..trainer_utils import IntervalStrategy
from ..trainer_utils import IntervalStrategy, has_length
def format_time(t):
@@ -294,7 +293,7 @@ class NotebookProgressCallback(TrainerCallback):
self._force_next_update = False
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
if not has_length(eval_dataloader):
return
if self.prediction_bar is None:
if self.training_tracker is not None:
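For reference, the `has_length` helper imported above from `trainer_utils` (and used throughout the Trainer and callback changes) behaves roughly like the sketch below; the exact implementation lives in `transformers/trainer_utils.py`.

def has_length(dataset):
    # True only if len(dataset) works; unsized objects raise TypeError on len().
    try:
        return len(dataset) is not None
    except TypeError:
        return False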
@@ -189,6 +189,26 @@ if is_torch_available():
yield self.dataset[self.current_sample]
self.current_sample += 1
class MultiLoader:
def __init__(self, loaders):
self.loaders = loaders
def __len__(self):
return sum(len(loader) for loader in self.loaders)
def __iter__(self):
for loader in self.loaders:
yield from loader
class CustomDataloaderTrainer(Trainer):
def get_train_dataloader(self):
dataloaders = [super().get_train_dataloader(), super().get_train_dataloader()]
return MultiLoader(dataloaders)
def get_eval_dataloader(self, eval_dataset):
dataloaders = [super().get_eval_dataloader(eval_dataset), super().get_eval_dataloader(eval_dataset)]
return MultiLoader(dataloaders)
class RegressionModel(nn.Module):
def __init__(self, a=0, b=0, double_output=False):
super().__init__()
@@ -647,6 +667,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
new_eval_dataset = RegressionDataset(length=128)
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
# tests that we do not require dataloader to have a .dataset attribute
def test_dataloader_without_dataset(self):
train_dataset = RegressionDataset(length=128)
trainer = CustomDataloaderTrainer(
model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset
)
trainer.train()
trainer.evaluate()
def test_sampler_seed(self):
# nb: we don't want to inherit from IterableDataset to hit the right code path
class DummyDataset(torch.utils.data.Dataset):