Unverified commit d7c8ce57, authored by Sander Land, committed by GitHub

Avoid accessing .dataset of a DataLoader in Trainer (#16451)



* Avoid accessing .dataset of a dataloader

* style

* fix

* cleaning up, reverting some misunderstandings

* black

* add train_dataset argument to get_train_dataloader, and fix other instances of length checks

* flake8

* address comments

* fix bug

* cleanup

* add test

* Update tests/trainer/test_trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* under torch

* merge

* stylistic suggestion
Co-authored-by: Sander Land <sander@chatdesk.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 781af736
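The change below makes Trainer stop assuming that the objects returned by get_train_dataloader / get_eval_dataloader are torch DataLoaders with a .dataset attribute. As a minimal sketch of the kind of loader this enables (illustrative only; the test added by this commit uses a similar MultiLoader wrapper, shown in the diff below):

class ChainedLoader:
    # DataLoader-like object: implements __len__ and __iter__, but has no .dataset attribute.
    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        return sum(len(loader) for loader in self.loaders)

    def __iter__(self):
        for loader in self.loaders:
            yield from loader

Before this commit, Trainer internals such as num_examples() called len(dataloader.dataset) directly, which raises AttributeError on objects like this.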
@@ -585,7 +585,7 @@ class Trainer:
         return dataset.remove_columns(ignored_columns)
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if not has_length(self.train_dataset):
+        if self.train_dataset is None or not has_length(self.train_dataset):
             return None
 
         generator = None
@@ -661,8 +661,8 @@ class Trainer:
         """
         Returns the training [`~torch.utils.data.DataLoader`].
 
-        Will use no sampler if `self.train_dataset` does not implement `__len__`, a random sampler (adapted to
-        distributed training if necessary) otherwise.
+        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+        training if necessary) otherwise.
 
         Subclass and override this method if you want to inject some custom behavior.
         """
@@ -937,11 +937,13 @@ class Trainer:
     def num_examples(self, dataloader: DataLoader) -> int:
         """
-        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset.
-
-        Will raise an exception if the underlying dataset does not implement method `__len__`
+        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
+        dataloader.dataset does not exist or has no length, estimates as best it can
         """
-        return len(dataloader.dataset)
+        try:
+            return len(dataloader.dataset)
+        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
+            return len(dataloader) * self.args.per_device_train_batch_size
 
     def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         """HP search setup code"""
@@ -1198,9 +1200,6 @@ class Trainer:
         self._move_model_to_device(self.model, args.device)
         self.model_wrapped = self.model
 
-        # Keeping track whether we can can len() on the dataset or not
-        train_dataset_is_sized = has_length(self.train_dataset)
-
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -1209,28 +1208,36 @@ class Trainer:
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
         total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
-        if train_dataset_is_sized:
-            num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
+
+        len_dataloader = None
+        if has_length(train_dataloader):
+            len_dataloader = len(train_dataloader)
+            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
             num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+            num_examples = self.num_examples(train_dataloader)
             if args.max_steps > 0:
                 max_steps = args.max_steps
                 num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                     args.max_steps % num_update_steps_per_epoch > 0
                 )
-                # May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's
+                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                 # the best we can do.
                 num_train_samples = args.max_steps * total_train_batch_size
             else:
                 max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                 num_train_epochs = math.ceil(args.num_train_epochs)
-                num_train_samples = len(self.train_dataset) * args.num_train_epochs
-        else:
-            # see __init__. max_steps is set when the dataset has no __len__
+                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
             max_steps = args.max_steps
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
             num_train_epochs = sys.maxsize
             num_update_steps_per_epoch = max_steps
+            num_examples = total_train_batch_size * args.max_steps
             num_train_samples = args.max_steps * total_train_batch_size
+        else:
+            raise ValueError(
+                f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
+            )
 
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
             if self.args.n_gpu > 1:
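To make the step/epoch arithmetic in the hunk above concrete, a small worked example with illustrative numbers (not taken from the diff): 400 batches per epoch, gradient accumulation of 4, 3 requested epochs, and max_steps left unset:

import math

len_dataloader = 400                 # batches per epoch reported by the dataloader
gradient_accumulation_steps = 4
num_train_epochs = 3.0               # args.num_train_epochs
max_steps_arg = -1                   # args.max_steps <= 0 means "derive from epochs"

num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)  # 100
if max_steps_arg > 0:
    max_steps = max_steps_arg
else:
    max_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)  # 300
print(num_update_steps_per_epoch, max_steps)  # 100 300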
@@ -1281,10 +1288,6 @@ class Trainer:
         # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
 
         # Train!
-        num_examples = (
-            self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
-        )
         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {num_examples}")
         logger.info(f"  Num Epochs = {num_train_epochs}")
@@ -1370,7 +1373,7 @@ class Trainer:
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
-            elif isinstance(train_dataloader.dataset, IterableDatasetShard):
+            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
                 train_dataloader.dataset.set_epoch(epoch)
 
             if is_torch_tpu_available():
@@ -1384,7 +1387,9 @@ class Trainer:
                 self._past = None
 
             steps_in_epoch = (
-                len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
+                len(epoch_iterator)
+                if len_dataloader is not None
+                else args.max_steps * args.gradient_accumulation_steps
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
@@ -2407,10 +2412,10 @@ class Trainer:
         elif args.bf16_full_eval:
             model = model.to(dtype=torch.bfloat16, device=args.device)
 
-        batch_size = dataloader.batch_size
+        batch_size = self.args.per_device_eval_batch_size
         logger.info(f"***** Running {description} *****")
-        if has_length(dataloader.dataset):
+        if has_length(dataloader):
             logger.info(f"  Num examples = {self.num_examples(dataloader)}")
         else:
             logger.info("  Num examples: Unknown")
@@ -2420,7 +2425,7 @@ class Trainer:
         self.callback_handler.eval_dataloader = dataloader
         # Do this before wrapping.
-        eval_dataset = dataloader.dataset
+        eval_dataset = getattr(dataloader, "dataset", None)
 
         if is_torch_tpu_available():
             dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
@@ -2512,6 +2517,9 @@ class Trainer:
         elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
             num_samples = eval_dataset.num_examples
         else:
-            num_samples = observed_num_examples
+            if has_length(dataloader):
+                num_samples = self.num_examples(dataloader)
+            else:  # both len(dataloader.dataset) and len(dataloader) fail
+                num_samples = observed_num_examples
 
         # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
@@ -2899,8 +2907,9 @@ class Trainer:
         """
         args = self.args
 
-        if not has_length(dataloader.dataset):
-            raise ValueError("dataset must implement __len__")
+        if not has_length(dataloader):
+            raise ValueError("dataloader must implement a working __len__")
+
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
         # if eval is called w/o train init deepspeed here
......
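Several hunks above and below replace has_length(dataloader.dataset) or direct len(...) calls with has_length(dataloader) on the dataloader itself. Roughly, has_length (a helper in transformers.trainer_utils) only checks whether len(obj) can be computed without raising; a minimal sketch of the idea (the actual helper may differ in detail):

def has_length(obj):
    # True if len(obj) works; False for iterable datasets or dataloaders
    # that do not expose a usable __len__.
    try:
        return len(obj) is not None
    except TypeError:
        return False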
@@ -473,7 +473,7 @@ class ProgressCallback(TrainerCallback):
             self.current_step = state.global_step
 
     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
-        if state.is_local_process_zero and has_length(eval_dataloader.dataset):
+        if state.is_local_process_zero and has_length(eval_dataloader):
             if self.prediction_bar is None:
                 self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
             self.prediction_bar.update(1)
......
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import re
 import time
 from typing import Optional
@@ -21,7 +20,7 @@ from typing import Optional
 import IPython.display as disp
 
 from ..trainer_callback import TrainerCallback
-from ..trainer_utils import IntervalStrategy
+from ..trainer_utils import IntervalStrategy, has_length
 
 
 def format_time(t):
@@ -294,7 +293,7 @@ class NotebookProgressCallback(TrainerCallback):
         self._force_next_update = False
 
     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
-        if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
+        if not has_length(eval_dataloader):
             return
         if self.prediction_bar is None:
             if self.training_tracker is not None:
......
@@ -189,6 +189,26 @@ if is_torch_available():
                 yield self.dataset[self.current_sample]
                 self.current_sample += 1
 
+    class MultiLoader:
+        def __init__(self, loaders):
+            self.loaders = loaders
+
+        def __len__(self):
+            return sum(len(loader) for loader in self.loaders)
+
+        def __iter__(self):
+            for loader in self.loaders:
+                yield from loader
+
+    class CustomDataloaderTrainer(Trainer):
+        def get_train_dataloader(self):
+            dataloaders = [super().get_train_dataloader(), super().get_train_dataloader()]
+            return MultiLoader(dataloaders)
+
+        def get_eval_dataloader(self, eval_dataset):
+            dataloaders = [super().get_eval_dataloader(eval_dataset), super().get_eval_dataloader(eval_dataset)]
+            return MultiLoader(dataloaders)
+
     class RegressionModel(nn.Module):
         def __init__(self, a=0, b=0, double_output=False):
             super().__init__()
@@ -647,6 +667,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         new_eval_dataset = RegressionDataset(length=128)
         self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
 
+    # tests that we do not require dataloader to have a .dataset attribute
+    def test_dataloader_without_dataset(self):
+        train_dataset = RegressionDataset(length=128)
+        trainer = CustomDataloaderTrainer(
+            model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset
+        )
+        trainer.train()
+        trainer.evaluate()
+
     def test_sampler_seed(self):
         # nb: we don't want to inherit from IterableDataset to hit the right code path
         class DummyDataset(torch.utils.data.Dataset):
......