"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "803a8cd18f44e7288187b82ae3b48956edf11dd6"
Unverified Commit 75b13f82 authored by Anton Lozhkov, committed by GitHub

[Trainer] Deeper length checks for IterableDatasetShard (#15539)



* Unused import

* Make `has_length()` torch-independent to use in callbacks

* Update src/transformers/trainer_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 84eec9e6
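The deeper check matters because `IterableDatasetShard` *defines* `__len__`, so the old `isinstance(dataset, collections.abc.Sized)` test passes even when the wrapped dataset has no length and `len()` only fails once it is actually called. Here is a minimal sketch of that failure mode, assuming simplified stand-ins (`Shard` and `UnsizedIterable` are illustrative, not the transformers classes):

```python
import collections.abc


class UnsizedIterable:
    """A streaming-style dataset that can be iterated but has no __len__."""

    def __iter__(self):
        yield from range(3)


class Shard:
    """Defines __len__ by deferring to the inner dataset, loosely
    mimicking IterableDatasetShard (simplified for illustration)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(self.dataset)

    def __len__(self):
        # Raises TypeError when the wrapped dataset is unsized.
        return len(self.dataset)


def has_length(dataset):
    # Same logic as the helper this commit adds to trainer_utils.py:
    # actually call len() instead of only checking that __len__ exists.
    try:
        return len(dataset) is not None
    except TypeError:
        return False


shard = Shard(UnsizedIterable())
print(isinstance(shard, collections.abc.Sized))  # True: __len__ is defined
print(has_length(shard))                         # False: len() raises TypeError
```

Because the helper only uses builtins, it also drops the torch dependency, which is what lets the callbacks module use it further down in this diff.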
src/transformers/trainer.py
@@ -16,7 +16,6 @@
 The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
 """

-import collections
 import contextlib
 import inspect
 import math
@@ -54,7 +53,7 @@ import numpy as np
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

 from huggingface_hub import Repository
@@ -126,6 +125,7 @@ from .trainer_utils import (
     default_hp_space,
     denumpify_detensorize,
     get_last_checkpoint,
+    has_length,
     number_of_arguments,
     set_seed,
     speed_metrics,
@@ -429,7 +429,7 @@ class Trainer:
         if args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

-        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
+        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
             raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

         if (
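With the check above, a dataset that fails `has_length()` can still be trained, as long as a step budget is given. A minimal usage sketch, assuming a hypothetical output directory name:

```python
from transformers import TrainingArguments

# Without a usable len(), num_train_epochs cannot be converted into steps,
# so the Trainer demands an explicit max_steps budget instead.
args = TrainingArguments(output_dir="out", max_steps=10_000)
```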
@@ -585,7 +585,7 @@ class Trainer:
         return dataset.remove_columns(ignored_columns)

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if not isinstance(self.train_dataset, collections.abc.Sized):
+        if not has_length(self.train_dataset):
             return None

         generator = None
@@ -1190,7 +1190,7 @@ class Trainer:
         self.model_wrapped = self.model

         # Keeping track whether we can call len() on the dataset or not
-        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
+        train_dataset_is_sized = has_length(self.train_dataset)

         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -2383,7 +2383,7 @@ class Trainer:
         batch_size = dataloader.batch_size

         logger.info(f"***** Running {description} *****")
-        if isinstance(dataloader.dataset, collections.abc.Sized):
+        if has_length(dataloader.dataset):
             logger.info(f"  Num examples = {self.num_examples(dataloader)}")
         else:
             logger.info("  Num examples: Unknown")
@@ -2478,7 +2478,7 @@ class Trainer:
             all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

         # Number of samples
-        if not isinstance(eval_dataset, IterableDataset):
+        if has_length(eval_dataset):
             num_samples = len(eval_dataset)
         # The instance check is weird and does not actually check for the type, but whether the dataset has the right
         # methods. Therefore we need to make sure it also has the attribute.
@@ -2872,7 +2872,7 @@ class Trainer:
         """
         args = self.args

-        if not isinstance(dataloader.dataset, collections.abc.Sized):
+        if not has_length(dataloader.dataset):
             raise ValueError("dataset must implement __len__")
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
src/transformers/trainer_callback.py
@@ -15,7 +15,6 @@
 """
 Callbacks to use with the Trainer class and customize the training loop.
 """
-import collections
 import dataclasses
 import json
 from dataclasses import dataclass
@@ -24,7 +23,7 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 from tqdm.auto import tqdm

-from .trainer_utils import IntervalStrategy
+from .trainer_utils import IntervalStrategy, has_length
 from .training_args import TrainingArguments
 from .utils import logging
@@ -470,7 +469,7 @@ class ProgressCallback(TrainerCallback):
             self.current_step = state.global_step

     def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
-        if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
+        if state.is_local_process_zero and has_length(eval_dataloader.dataset):
             if self.prediction_bar is None:
                 self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
             self.prediction_bar.update(1)
src/transformers/trainer_utils.py
@@ -519,6 +519,17 @@ class TrainerMemoryTracker:
         self.update_metrics(stage, metrics)


+def has_length(dataset):
+    """
+    Checks if the dataset implements __len__() without raising an error.
+    """
+    try:
+        return len(dataset) is not None
+    except TypeError:
+        # TypeError: len() of unsized object
+        return False
+
+
 def denumpify_detensorize(metrics):
     """
     Recursively calls `.item()` on the element of the dictionary passed
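For reference, a few illustrative calls showing how the new helper behaves (not part of the diff; assumes the post-commit import path):

```python
from transformers.trainer_utils import has_length

print(has_length([1, 2, 3]))        # True: lists implement __len__
print(has_length(iter([1, 2, 3])))  # False: len() raises TypeError on iterators
print(has_length({"a": 0}))         # True: dicts are sized
```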