Unverified Commit 1affde2f authored by Sylvain Gugger, committed by GitHub

Make DataCollator a callable (#5015)



* Make DataCollator a callable

* Update src/transformers/data/data_collator.py
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent f7c93b3c
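
In short: `DataCollator` is no longer an abstract class with a `collate_batch` method; it becomes a type alias for any callable that turns a list of dataset samples into a dict of tensors, and the old `DefaultDataCollator` class is replaced by the plain function `default_data_collator`. A minimal before/after sketch (the `ToyFeature` dataclass below is illustrative only, not part of this commit):

from dataclasses import dataclass
from typing import List

from transformers import default_data_collator

# Hypothetical feature type for illustration; real dataset features (e.g. from GlueDataset)
# are similar attribute-style objects carrying input_ids, attention_mask, label, ...
@dataclass
class ToyFeature:
    input_ids: List[int]
    attention_mask: List[int]
    label: int

features = [
    ToyFeature(input_ids=[101, 7592, 102], attention_mask=[1, 1, 1], label=0),
    ToyFeature(input_ids=[101, 2088, 102], attention_mask=[1, 1, 1], label=1),
]

# Before this commit: batch = DefaultDataCollator().collate_batch(features)
# After this commit the collator is simply called:
batch = default_data_collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"].dtype)     # torch.int64 (int labels are collated as long tensors)
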
@@ -38,7 +38,6 @@ from transformers import (
     BertConfig,
     BertForSequenceClassification,
     BertTokenizer,
-    DefaultDataCollator,
     DistilBertConfig,
     DistilBertForSequenceClassification,
     DistilBertTokenizer,
@@ -51,6 +50,7 @@ from transformers import (
     XLNetConfig,
     XLNetForSequenceClassification,
     XLNetTokenizer,
+    default_data_collator,
     get_linear_schedule_with_warmup,
 )
 from utils_hans import HansDataset, hans_output_modes, hans_processors
@@ -91,10 +91,7 @@ def train(args, train_dataset, model, tokenizer):
     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
     train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
     train_dataloader = DataLoader(
-        train_dataset,
-        sampler=train_sampler,
-        batch_size=args.train_batch_size,
-        collate_fn=DefaultDataCollator().collate_batch,
+        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=default_data_collator,
     )
 
     if args.max_steps > 0:
@@ -252,10 +249,7 @@ def evaluate(args, model, tokenizer, label_list, prefix=""):
     # Note that DistributedSampler samples randomly
     eval_sampler = SequentialSampler(eval_dataset)
     eval_dataloader = DataLoader(
-        eval_dataset,
-        sampler=eval_sampler,
-        batch_size=args.eval_batch_size,
-        collate_fn=DefaultDataCollator().collate_batch,
+        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=default_data_collator,
     )
 
     # multi-gpu eval
...
@@ -34,8 +34,8 @@ from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
     AutoTokenizer,
-    DefaultDataCollator,
     GlueDataset,
+    default_data_collator,
     glue_compute_metrics,
     glue_output_modes,
     glue_processors,
@@ -424,7 +424,7 @@ def main():
     eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
     eval_dataloader = DataLoader(
-        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch
+        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
     )
 
     # Compute head entropy and importance score
...
@@ -364,7 +364,7 @@ if is_torch_available():
 
     # Trainer
     from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
-    from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
+    from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
     from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
 
     # Benchmarks
...
@@ -1,6 +1,5 @@
-from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, List, NewType, Tuple
+from typing import Any, Callable, Dict, List, NewType, Tuple
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
@@ -8,28 +7,16 @@ from torch.nn.utils.rnn import pad_sequence
 from ..tokenization_utils import PreTrainedTokenizer
 
 
-class DataCollator(ABC):
-    """
-    A `DataCollator` is responsible for batching
-    and pre-processing samples of data as requested by the training loop.
-    """
-
-    @abstractmethod
-    def collate_batch(self) -> Dict[str, torch.Tensor]:
-        """
-        Take a list of samples from a Dataset and collate them into a batch.
-
-        Returns:
-            A dictionary of tensors
-        """
-        pass
-
-
 InputDataClass = NewType("InputDataClass", Any)
 
-
-@dataclass
-class DefaultDataCollator(DataCollator):
+"""
+A DataCollator is a function that takes a list of samples from a Dataset
+and collate them into a batch, as a dictionary of Tensors.
+"""
+DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
+
+
+def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
     """
     Very simple data collator that:
     - simply collates batches of dict-like objects
@@ -42,41 +29,40 @@ class DefaultDataCollator(DataCollator):
     See glue and ner for example of how it's useful.
     """
 
-    def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
-        # In this method we'll make the assumption that all `features` in the batch
-        # have the same attributes.
-        # So we will look at the first element as a proxy for what attributes exist
-        # on the whole batch.
-        first = features[0]
-
-        # Special handling for labels.
-        # Ensure that tensor is created with the correct type
-        # (it should be automatically the case, but let's make sure of it.)
-        if hasattr(first, "label") and first.label is not None:
-            if type(first.label) is int:
-                labels = torch.tensor([f.label for f in features], dtype=torch.long)
-            else:
-                labels = torch.tensor([f.label for f in features], dtype=torch.float)
-            batch = {"labels": labels}
-        elif hasattr(first, "label_ids") and first.label_ids is not None:
-            if type(first.label_ids[0]) is int:
-                labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
-            else:
-                labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
-            batch = {"labels": labels}
-        else:
-            batch = {}
-
-        # Handling of all other possible attributes.
-        # Again, we will use the first element to figure out which key/values are not None for this model.
-        for k, v in vars(first).items():
-            if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
-                batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
-        return batch
+    # In this function we'll make the assumption that all `features` in the batch
+    # have the same attributes.
+    # So we will look at the first element as a proxy for what attributes exist
+    # on the whole batch.
+    first = features[0]
+
+    # Special handling for labels.
+    # Ensure that tensor is created with the correct type
+    # (it should be automatically the case, but let's make sure of it.)
+    if hasattr(first, "label") and first.label is not None:
+        if type(first.label) is int:
+            labels = torch.tensor([f.label for f in features], dtype=torch.long)
+        else:
+            labels = torch.tensor([f.label for f in features], dtype=torch.float)
+        batch = {"labels": labels}
+    elif hasattr(first, "label_ids") and first.label_ids is not None:
+        if type(first.label_ids[0]) is int:
+            labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
+        else:
+            labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
+        batch = {"labels": labels}
+    else:
+        batch = {}
+
+    # Handling of all other possible attributes.
+    # Again, we will use the first element to figure out which key/values are not None for this model.
+    for k, v in vars(first).items():
+        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
+            batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
+    return batch
 
 
 @dataclass
-class DataCollatorForLanguageModeling(DataCollator):
+class DataCollatorForLanguageModeling:
     """
     Data collator used for language modeling.
     - collates batches of tensors, honoring their tokenizer's pad_token
@@ -87,7 +73,7 @@ class DataCollatorForLanguageModeling(DataCollator):
     mlm: bool = True
     mlm_probability: float = 0.15
 
-    def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
         batch = self._tensorize_batch(examples)
         if self.mlm:
             inputs, labels = self.mask_tokens(batch)
...
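
Since `DataCollator` is now just `Callable[[List[InputDataClass]], Dict[str, torch.Tensor]]`, a custom collator no longer needs to inherit from anything; it only has to be callable, the same way `DataCollatorForLanguageModeling` now is via `__call__`. A sketch of one way to write one (the class name and the padding logic below are illustrative assumptions, not part of this diff):

from dataclasses import dataclass
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence

from transformers import PreTrainedTokenizer


@dataclass
class PaddingDataCollator:
    """Illustrative collator: pads variable-length sequences of token ids to the batch maximum."""

    tokenizer: PreTrainedTokenizer

    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Assumes each example is a 1-D tensor of token ids and the tokenizer defines a pad token.
        input_ids = pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        return {"input_ids": input_ids, "attention_mask": attention_mask}

# A plain function with the same signature works just as well, since the
# DataCollator type is only a Callable hint.
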
@@ -19,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
-from .data.data_collator import DataCollator, DefaultDataCollator
+from .data.data_collator import DataCollator, default_data_collator
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
@@ -190,10 +190,7 @@ class Trainer:
         """
         self.model = model.to(args.device)
         self.args = args
-        if data_collator is not None:
-            self.data_collator = data_collator
-        else:
-            self.data_collator = DefaultDataCollator()
+        self.data_collator = data_collator if data_collator is not None else default_data_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.compute_metrics = compute_metrics
@@ -239,7 +236,7 @@ class Trainer:
             self.train_dataset,
             batch_size=self.args.train_batch_size,
             sampler=train_sampler,
-            collate_fn=self.data_collator.collate_batch,
+            collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
 
@@ -264,7 +261,7 @@ class Trainer:
             eval_dataset,
             sampler=sampler,
             batch_size=self.args.eval_batch_size,
-            collate_fn=self.data_collator.collate_batch,
+            collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
 
@@ -285,7 +282,7 @@ class Trainer:
             test_dataset,
             sampler=sampler,
             batch_size=self.args.eval_batch_size,
-            collate_fn=self.data_collator.collate_batch,
+            collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
         )
...
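
On the `Trainer` side the collator is stored as-is and handed straight to each `DataLoader` as `collate_fn`, so any callable can be passed as `data_collator`, and omitting it falls back to `default_data_collator`. A usage sketch under those assumptions (`model` and `train_dataset` are placeholders, not defined in this commit):

from transformers import Trainer, TrainingArguments, default_data_collator

def my_collator(features):
    # Any callable taking a list of samples and returning a dict of tensors
    # satisfies the new DataCollator contract; here we simply delegate.
    return default_data_collator(features)

trainer = Trainer(
    model=model,                                 # assumed: a PreTrainedModel instance
    args=TrainingArguments(output_dir="./out"),
    train_dataset=train_dataset,                 # assumed: a torch Dataset of features
    data_collator=my_collator,                   # optional; defaults to default_data_collator
)
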
@@ -11,7 +11,7 @@ if is_torch_available():
         Trainer,
         LineByLineTextDataset,
         AutoModelForSequenceClassification,
-        DefaultDataCollator,
+        default_data_collator,
         DataCollatorForLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
             task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
         )
         dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
-        data_collator = DefaultDataCollator()
-        batch = data_collator.collate_batch(dataset.features)
+        data_collator = default_data_collator
+        batch = data_collator(dataset.features)
         self.assertEqual(batch["labels"].dtype, torch.long)
 
     def test_default_regression(self):
@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
             task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
         )
         dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
-        data_collator = DefaultDataCollator()
-        batch = data_collator.collate_batch(dataset.features)
+        data_collator = default_data_collator
+        batch = data_collator(dataset.features)
         self.assertEqual(batch["labels"].dtype, torch.float)
 
     def test_lm_tokenizer_without_padding(self):
@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         examples = [dataset[i] for i in range(len(dataset))]
         with self.assertRaises(ValueError):
             # Expect error due to padding token missing on gpt2:
-            data_collator.collate_batch(examples)
+            data_collator(examples)
 
         dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
         examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator.collate_batch(examples)
+        batch = data_collator(examples)
         self.assertIsInstance(batch, dict)
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
         examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator.collate_batch(examples)
+        batch = data_collator(examples)
         self.assertIsInstance(batch, dict)
         self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
         self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
 
         dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
         examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator.collate_batch(examples)
+        batch = data_collator(examples)
         self.assertIsInstance(batch, dict)
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
...
@@ -29,7 +29,7 @@ if is_torch_available():
     from torch import nn
     from torch.utils.data.dataset import Dataset
 
-    from transformers import DataCollator, Trainer
+    from transformers import Trainer
 
     class DummyDataset(Dataset):
         def __init__(self, length: int = 101):
@@ -41,8 +41,8 @@ if is_torch_available():
         def __getitem__(self, i) -> int:
            return i
 
-    class DummyDataCollator(DataCollator):
-        def collate_batch(self, features):
+    class DummyDataCollator:
+        def __call__(self, features):
             return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
 
     class DummyModel(nn.Module):
...