Unverified Commit 1affde2f authored by Sylvain Gugger, committed by GitHub

Make DataCollator a callable (#5015)



* Make DataCollator a callable

* Update src/transformers/data/data_collator.py
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent f7c93b3c
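The net effect of this change: a data collator is no longer a subclass of an abstract `DataCollator` base class implementing `collate_batch`, but any callable that maps a list of dataset features to a dictionary of tensors. A minimal sketch of the new contract (`ToyFeature` and `toy_collator` are hypothetical names used only for illustration, not part of this diff):

```python
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class ToyFeature:
    # Hypothetical feature type used only to show the new callable contract.
    input_ids: List[int]
    label: int


def toy_collator(features: List[ToyFeature]) -> Dict[str, torch.Tensor]:
    # Any function (or object implementing __call__) with this shape can now be
    # used wherever a DataCollator is expected, e.g. as a DataLoader collate_fn.
    return {
        "input_ids": torch.tensor([f.input_ids for f in features], dtype=torch.long),
        "labels": torch.tensor([f.label for f in features], dtype=torch.long),
    }


batch = toy_collator([ToyFeature([101, 102], 0), ToyFeature([101, 102], 1)])
assert batch["input_ids"].shape == (2, 2)
```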
......@@ -38,7 +38,6 @@ from transformers import (
BertConfig,
BertForSequenceClassification,
BertTokenizer,
DefaultDataCollator,
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer,
......@@ -51,6 +50,7 @@ from transformers import (
XLNetConfig,
XLNetForSequenceClassification,
XLNetTokenizer,
default_data_collator,
get_linear_schedule_with_warmup,
)
from utils_hans import HansDataset, hans_output_modes, hans_processors
......@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=DefaultDataCollator().collate_batch,
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
)
if args.max_steps > 0:
......@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
eval_dataset,
sampler=eval_sampler,
batch_size=args.eval_batch_size,
collate_fn=DefaultDataCollator().collate_batch,
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
)
# multi-gpu eval
......
......@@ -34,8 +34,8 @@ from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
DefaultDataCollator,
GlueDataset,
default_data_collator,
glue_compute_metrics,
glue_output_modes,
glue_processors,
......@@ -424,7 +424,7 @@ def main():
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch
eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
)
# Compute head entropy and importance score
......
......@@ -364,7 +364,7 @@ if is_torch_available():
# Trainer
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
# Benchmarks
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Tuple
from typing import Any, Callable, Dict, List, NewType, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
......@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
from ..tokenization_utils import PreTrainedTokenizer
class DataCollator(ABC):
"""
A `DataCollator` is responsible for batching
and pre-processing samples of data as requested by the training loop.
"""
@abstractmethod
def collate_batch(self) -> Dict[str, torch.Tensor]:
"""
Take a list of samples from a Dataset and collate them into a batch.
Returns:
A dictionary of tensors
"""
pass
InputDataClass = NewType("InputDataClass", Any)
"""
A DataCollator is a function that takes a list of samples from a Dataset
and collates them into a batch, as a dictionary of Tensors.
"""
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
@dataclass
class DefaultDataCollator(DataCollator):
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
"""
Very simple data collator that:
- simply collates batches of dict-like objects
......@@ -42,41 +29,40 @@ class DefaultDataCollator(DataCollator):
See glue and ner for examples of how it's useful.
"""
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
# In this method we'll make the assumption that all `features` in the batch
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
first = features[0]
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if hasattr(first, "label") and first.label is not None:
if type(first.label) is int:
labels = torch.tensor([f.label for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label for f in features], dtype=torch.float)
batch = {"labels": labels}
elif hasattr(first, "label_ids") and first.label_ids is not None:
if type(first.label_ids[0]) is int:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
batch = {"labels": labels}
# In this function we'll make the assumption that all `features` in the batch
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
first = features[0]
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if hasattr(first, "label") and first.label is not None:
if type(first.label) is int:
labels = torch.tensor([f.label for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label for f in features], dtype=torch.float)
batch = {"labels": labels}
elif hasattr(first, "label_ids") and first.label_ids is not None:
if type(first.label_ids[0]) is int:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
else:
batch = {}
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
batch = {"labels": labels}
else:
batch = {}
# Handling of all other possible attributes.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in vars(first).items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
return batch
# Handling of all other possible attributes.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in vars(first).items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
return batch
@dataclass
class DataCollatorForLanguageModeling(DataCollator):
class DataCollatorForLanguageModeling:
"""
Data collator used for language modeling.
- collates batches of tensors, honoring their tokenizer's pad_token
......@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
mlm: bool = True
mlm_probability: float = 0.15
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = self._tensorize_batch(examples)
if self.mlm:
inputs, labels = self.mask_tokens(batch)
......
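With the abstract base class gone, the built-in collators are used as plain callables: `default_data_collator` is a module-level function passed as-is (instead of `DefaultDataCollator().collate_batch`), and `DataCollatorForLanguageModeling` is invoked through `__call__`. A hedged usage sketch follows; the `ExampleFeature` dataclass, the model name, and the sample sentences are illustrative assumptions and mimic, but are not, the library's own feature classes:

```python
from dataclasses import dataclass
from typing import List

import torch
from transformers import BertTokenizer, DataCollatorForLanguageModeling, default_data_collator


@dataclass
class ExampleFeature:
    # Hypothetical stand-in for a GLUE-style feature; default_data_collator only
    # needs attribute access to "label"/"label_ids" and to tensorizable fields.
    input_ids: List[int]
    attention_mask: List[int]
    label: int


features = [
    ExampleFeature(input_ids=[101, 2023, 102], attention_mask=[1, 1, 1], label=0),
    ExampleFeature(input_ids=[101, 2008, 102], attention_mask=[1, 1, 1], label=1),
]

# Before this change: DefaultDataCollator().collate_batch(features)
batch = default_data_collator(features)
assert batch["labels"].dtype == torch.long  # int labels -> long, float labels -> float

# DataCollatorForLanguageModeling keeps its dataclass configuration but is now
# called directly (via __call__) instead of through .collate_batch.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
examples = [torch.tensor(tokenizer.encode(s)) for s in ["Hello world.", "Hello there."]]
lm_batch = lm_collator(examples)
assert "input_ids" in lm_batch and "labels" in lm_batch
```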
......@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DefaultDataCollator
from .data.data_collator import DataCollator, default_data_collator
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
......@@ -190,10 +190,7 @@ class Trainer:
"""
self.model = model.to(args.device)
self.args = args
if data_collator is not None:
self.data_collator = data_collator
else:
self.data_collator = DefaultDataCollator()
self.data_collator = data_collator if data_collator is not None else default_data_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
......@@ -239,7 +236,7 @@ class Trainer:
self.train_dataset,
batch_size=self.args.train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator.collate_batch,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
)
......@@ -264,7 +261,7 @@ class Trainer:
eval_dataset,
sampler=sampler,
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
)
......@@ -285,7 +282,7 @@ class Trainer:
test_dataset,
sampler=sampler,
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
)
......
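In `Trainer`, the `data_collator` argument is now any such callable, and when it is left as `None` the module-level `default_data_collator` is used rather than instantiating `DefaultDataCollator`. A minimal sketch under that assumption (the model name, output directory, and `pairwise_collator` below are illustrative placeholders, not part of this diff):

```python
from typing import Dict, List

import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


def pairwise_collator(features: List) -> Dict[str, torch.Tensor]:
    # Illustrative custom collator: a plain function, no DataCollator subclass
    # and no collate_batch method required anymore.
    return {
        "input_ids": torch.tensor([f.input_ids for f in features], dtype=torch.long),
        "labels": torch.tensor([f.label for f in features], dtype=torch.long),
    }


model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
training_args = TrainingArguments(output_dir="./example_output")  # placeholder path

# data_collator accepts the function directly; leaving it as None falls back
# to default_data_collator inside Trainer.__init__.
trainer = Trainer(model=model, args=training_args, data_collator=pairwise_collator)
```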
......@@ -11,7 +11,7 @@ if is_torch_available():
Trainer,
LineByLineTextDataset,
AutoModelForSequenceClassification,
DefaultDataCollator,
default_data_collator,
DataCollatorForLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
......@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
)
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator()
batch = data_collator.collate_batch(dataset.features)
data_collator = default_data_collator
batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.long)
def test_default_regression(self):
......@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
)
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator()
batch = data_collator.collate_batch(dataset.features)
data_collator = default_data_collator
batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.float)
def test_lm_tokenizer_without_padding(self):
......@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
examples = [dataset[i] for i in range(len(dataset))]
with self.assertRaises(ValueError):
# Expect error due to padding token missing on gpt2:
data_collator.collate_batch(examples)
data_collator(examples)
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
......@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
......
......@@ -29,7 +29,7 @@ if is_torch_available():
from torch import nn
from torch.utils.data.dataset import Dataset
from transformers import DataCollator, Trainer
from transformers import Trainer
class DummyDataset(Dataset):
def __init__(self, length: int = 101):
......@@ -41,8 +41,8 @@ if is_torch_available():
def __getitem__(self, i) -> int:
return i
class DummyDataCollator(DataCollator):
def collate_batch(self, features):
class DummyDataCollator:
def __call__(self, features):
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
class DummyModel(nn.Module):
......