Unverified commit 75b13f82, authored by Anton Lozhkov, committed by GitHub

[Trainer] Deeper length checks for IterableDatasetShard (#15539)



* Unused import

* Make `has_length()` torch-independent to use in callbacks

* Update src/transformers/trainer_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

parent 84eec9e6
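For context on the title: `isinstance(dataset, collections.abc.Sized)`, the check this PR replaces throughout, only tests whether a `__len__` method exists on the class; it never calls it. A wrapper like `IterableDatasetShard` always defines `__len__` (so the `Sized` check passes) but delegates to the wrapped dataset, so `len()` can still raise `TypeError` when the inner dataset is unsized. The new `has_length()` goes deeper by actually calling `len()`. A minimal sketch of the mismatch, with a hypothetical `Shard` class standing in for the real `IterableDatasetShard`:

```python
import collections.abc


class Shard:
    """Simplified stand-in for IterableDatasetShard: always defines __len__,
    but len() only succeeds when the wrapped dataset has a length."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        yield from self.dataset

    def __len__(self):
        return len(self.dataset)  # raises TypeError for unsized datasets


stream = (i for i in range(10))  # a generator has no __len__
shard = Shard(stream)

print(isinstance(shard, collections.abc.Sized))  # True, but misleading:
# len(shard)      # raises TypeError at call time
# has_length(shard) (added by this PR) returns False instead, because it
# actually calls len() and treats a TypeError as "no length available"
```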
src/transformers/trainer.py

@@ -16,7 +16,6 @@
 The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
 """
-import collections
 import contextlib
 import inspect
 import math
@@ -54,7 +53,7 @@ import numpy as np
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

 from huggingface_hub import Repository
@@ -126,6 +125,7 @@ from .trainer_utils import (
     default_hp_space,
     denumpify_detensorize,
     get_last_checkpoint,
+    has_length,
     number_of_arguments,
     set_seed,
     speed_metrics,
@@ -429,7 +429,7 @@ class Trainer:
         if args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

-        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
+        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
             raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

         if (
@@ -585,7 +585,7 @@ class Trainer:
         return dataset.remove_columns(ignored_columns)

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if not isinstance(self.train_dataset, collections.abc.Sized):
+        if not has_length(self.train_dataset):
             return None

         generator = None
@@ -1190,7 +1190,7 @@ class Trainer:
         self.model_wrapped = self.model

         # Keeping track whether we can len() on the dataset or not
-        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
+        train_dataset_is_sized = has_length(self.train_dataset)

         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -2383,7 +2383,7 @@ class Trainer:
         batch_size = dataloader.batch_size

         logger.info(f"***** Running {description} *****")
-        if isinstance(dataloader.dataset, collections.abc.Sized):
+        if has_length(dataloader.dataset):
             logger.info(f"  Num examples = {self.num_examples(dataloader)}")
         else:
             logger.info("  Num examples: Unknown")
@@ -2478,7 +2478,7 @@ class Trainer:
             all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

         # Number of samples
-        if not isinstance(eval_dataset, IterableDataset):
+        if has_length(eval_dataset):
             num_samples = len(eval_dataset)
         # The instance check is weird and does not actually check for the type, but whether the dataset has the right
         # methods. Therefore we need to make sure it also has the attribute.
@@ -2872,7 +2872,7 @@ class Trainer:
         """
         args = self.args

-        if not isinstance(dataloader.dataset, collections.abc.Sized):
+        if not has_length(dataloader.dataset):
             raise ValueError("dataset must implement __len__")

         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
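On the `@@ -2478` hunk: the old `not isinstance(eval_dataset, IterableDataset)` guard let sized iterable datasets fall through to the attribute-based fallback, while the new `has_length(eval_dataset)` takes `len()` whenever it actually works. The comment kept in the diff refers to the fallback just below it (not shown in this hunk), which also checks for a `num_examples` attribute. A rough sketch of the resulting resolution order; the helper name `_resolve_num_samples` and the exact fallback variable are illustrative, not verbatim from trainer.py:

```python
from transformers.trainer_utils import has_length


def _resolve_num_samples(eval_dataset, observed_num_examples):
    """Illustrative only: mirrors the order of checks around the 2478 hunk."""
    if has_length(eval_dataset):  # len() works and does not raise
        return len(eval_dataset)
    # The IterableDatasetShard instance check is structural (it matches on
    # methods, not the concrete type), so the attribute is checked explicitly.
    if hasattr(eval_dataset, "num_examples"):
        return eval_dataset.num_examples
    # Fall back to the number of examples actually seen in the eval loop.
    return observed_num_examples
```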
src/transformers/trainer_callback.py

@@ -15,7 +15,6 @@
 """
 Callbacks to use with the Trainer class and customize the training loop.
 """
-import collections
 import dataclasses
 import json
 from dataclasses import dataclass
@@ -24,7 +23,7 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 from tqdm.auto import tqdm

-from .trainer_utils import IntervalStrategy
+from .trainer_utils import IntervalStrategy, has_length
 from .training_args import TrainingArguments
 from .utils import logging
@@ -470,7 +469,7 @@ class ProgressCallback(TrainerCallback):
             self.current_step = state.global_step

     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
-        if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
+        if state.is_local_process_zero and has_length(eval_dataloader.dataset):
             if self.prediction_bar is None:
                 self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
             self.prediction_bar.update(1)
src/transformers/trainer_utils.py

@@ -519,6 +519,17 @@ class TrainerMemoryTracker:
         self.update_metrics(stage, metrics)


+def has_length(dataset):
+    """
+    Checks if the dataset implements __len__() and it doesn't raise an error
+    """
+    try:
+        return len(dataset) is not None
+    except TypeError:
+        # TypeError: len() of unsized object
+        return False
+
+
 def denumpify_detensorize(metrics):
     """
     Recursively calls `.item()` on the element of the dictionary passed
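As a quick demonstration of what the `TypeError` branch in `has_length()` catches (not part of the PR; `StreamingDataset` is a made-up example class):

```python
import numpy as np
from torch.utils.data import IterableDataset

from transformers.trainer_utils import has_length


class StreamingDataset(IterableDataset):
    """An iterable dataset that, like most streams, has no __len__."""

    def __iter__(self):
        yield from range(10)


print(has_length([1, 2, 3]))           # True: len() returns 3
print(has_length(StreamingDataset()))  # False: len() raises TypeError
print(has_length(np.array(5)))         # False: a 0-d array raises the exact
                                       # "len() of unsized object" TypeError
```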