Unverified Commit 9c4aa4ac authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Clean up data collators and datasets (#8308)



* Clean up data collators and datasets

* Apply suggestions from code review
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Remove needless clone
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent b1d3e95e
......@@ -264,7 +264,15 @@ def main():
def tokenize_function(examples):
# Remove empty lines
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
return tokenizer(
examples["text"],
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
# receives the `special_tokens_mask`.
return_special_tokens_mask=True,
)
tokenized_datasets = datasets.map(
tokenize_function,
......@@ -275,8 +283,10 @@ def main():
)
else:
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
# efficient when it receives the `special_tokens_mask`.
def tokenize_function(examples):
return tokenizer(examples[text_column_name])
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
tokenized_datasets = datasets.map(
tokenize_function,
......
......@@ -281,7 +281,6 @@ if is_torch_available():
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
DataCollatorForTokenClassification,
......
import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
......@@ -175,72 +176,111 @@ class DataCollatorForTokenClassification:
return batch
def _collate_batch(examples, tokenizer):
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
# Tensorize if necessary.
if isinstance(examples[0], (list, tuple)):
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
# Check if padding is necessary.
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)
# If yes, check if we have a `pad_token`.
if tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({tokenizer.__class__.__name__}) does not have a pad token."
)
# Creating the full tensor and filling it with our data.
max_length = max(x.size(0) for x in examples)
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
for i, example in enumerate(examples):
if tokenizer.padding_side == "right":
result[i, : example.shape[0]] = example
else:
result[i, -example.shape[0] :] = example
return result
@dataclass
class DataCollatorForLanguageModeling:
"""
Data collator used for language modeling.
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
are not all of the same length.
- collates batches of tensors, honoring their tokenizer's pad_token
- preprocesses batches for masked language modeling
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
non-masked tokens and the value to predict for the masked token.
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
.. note::
For best performance, this data collator should be used with a dataset having items that are dictionaries or
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
argument :obj:`return_special_tokens_mask=True`.
"""
tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
def __call__(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
batch = self.tokenizer.pad(examples, return_tensors="pt")
else:
batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
inputs, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "labels": labels}
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = batch.clone().detach()
labels = batch["input_ids"]
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}
def _tensorize_batch(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> torch.Tensor:
# In order to accept both lists of lists and lists of Tensors
if isinstance(examples[0], (list, tuple)):
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)
else:
if self.tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({self.tokenizer.__class__.__name__}) does not have one."
)
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
batch["labels"] = labels
return batch
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def mask_tokens(
self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
......@@ -385,9 +425,16 @@ class DataCollatorForSOP(DataCollatorForLanguageModeling):
- preprocesses batches for both masked language modeling and sentence order prediction
"""
def __init__(self, *args, **kwargs):
warnings.warn(
"DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
"DataCollatorForLanguageModeling instead.",
FutureWarning,
)
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
input_ids = [example["input_ids"] for example in examples]
input_ids = self._tensorize_batch(input_ids)
input_ids = _collate_batch(input_ids, self.tokenizer)
input_ids, labels, attention_mask = self.mask_tokens(input_ids)
token_type_ids = [example["token_type_ids"] for example in examples]
......@@ -582,136 +629,3 @@ class DataCollatorForPermutationLanguageModeling:
) & masked_indices[i]
return inputs.long(), perm_mask, target_mapping, labels.long()
@dataclass
class DataCollatorForNextSentencePrediction:
"""
Data collator used for next sentence prediction. - collates examples which contains pre-generated negative examples
- preprocesses batches for masked language modeling
"""
tokenizer: PreTrainedTokenizerBase
mlm: bool = True
block_size: int = 512
short_seq_probability: float = 0.1
nsp_probability: float = 0.5
mlm_probability: float = 0.15
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""
The input should contain negative examples, :class:`~transformers.DataCollatorForNextSentencePrediction` will
not generate any negative examples
Args:
examples (:obj:`List[Dict]`): Each dictionary should have the following keys:
- ``tokens_a``: A sequence of tokens, which should appear before ``tokens_b`` in the text.
- ``tokens_b``: A sequence of tokens, which should appear after ``tokens_a`` in the text.
- ``is_random_next``: 1 if this pair is generated randomly, else 0.
"""
tokens_a = [e["tokens_a"] for e in examples]
tokens_b = [e["tokens_b"] for e in examples]
nsp_labels = [1 if e["is_random_next"] else 0 for e in examples]
input_ids = []
segment_ids = []
attention_masks = []
assert len(tokens_a) == len(tokens_b)
for i in range(len(tokens_a)):
input_id, attention_mask, segment_id = self.create_features_from_example(tokens_a[i], tokens_b[i])
input_ids.append(input_id)
segment_ids.append(segment_id)
attention_masks.append(attention_mask)
if self.mlm:
input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids))
else:
input_ids = self._tensorize_batch(input_ids)
result = {
"input_ids": input_ids,
"attention_mask": self._tensorize_batch(attention_masks),
"token_type_ids": self._tensorize_batch(segment_ids),
"labels": mlm_labels if self.mlm else None,
"next_sentence_label": torch.tensor(nsp_labels),
}
return result
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)
else:
if self.tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({self.tokenizer.__class__.__name__}) does not have one."
)
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
def create_features_from_example(self, tokens_a, tokens_b):
"""Creates examples for a single document."""
max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences(
tokens_a,
tokens_b,
num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens,
truncation_strategy="longest_first",
)
input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
attention_mask = [1] * len(input_id)
segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
assert len(input_id) <= self.block_size
# pad
while len(input_id) < self.block_size:
input_id.append(0)
attention_mask.append(0)
segment_id.append(0)
input_id = torch.tensor(input_id)
attention_mask = torch.tensor(attention_mask)
segment_id = torch.tensor(segment_id)
return input_id, attention_mask, segment_id
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
......@@ -3,6 +3,7 @@ import os
import pickle
import random
import time
import warnings
from typing import Dict, List, Optional
import torch
......@@ -17,6 +18,11 @@ from ...utils import logging
logger = logging.get_logger(__name__)
DEPRECATION_WARNING = (
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library."
)
class TextDataset(Dataset):
"""
This will be superseded by a framework-agnostic approach soon.
......@@ -30,6 +36,7 @@ class TextDataset(Dataset):
overwrite_cache=False,
cache_dir: Optional[str] = None,
):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
......@@ -94,6 +101,7 @@ class LineByLineTextDataset(Dataset):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
# Here, we do not cache the features, operating under the assumption
# that we will soon use fast multithreaded tokenizers from the
......@@ -120,6 +128,7 @@ class LineByLineWithRefDataset(Dataset):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
assert os.path.isfile(ref_path), f"Ref file path {file_path} not found"
# Here, we do not cache the features, operating under the assumption
......@@ -156,6 +165,7 @@ class LineByLineWithSOPTextDataset(Dataset):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
assert os.path.isdir(file_dir)
logger.info(f"Creating features from dataset file folder at {file_dir}")
self.examples = []
......@@ -305,6 +315,7 @@ class TextDatasetForNextSentencePrediction(Dataset):
short_seq_probability=0.1,
nsp_probability=0.5,
):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
self.block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)
......@@ -449,9 +460,18 @@ class TextDatasetForNextSentencePrediction(Dataset):
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
self.examples.append(
{"tokens_a": tokens_a, "tokens_b": tokens_b, "is_random_next": is_random_next}
)
# add special tokens
input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
# add token type ids, 0 for sentence a, 1 for sentence b
token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
example = {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
"next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
}
self.examples.append(example)
current_chunk = []
current_length = 0
......
......@@ -26,11 +26,6 @@ class DataCollatorForLanguageModeling:
requires_pytorch(self)
class DataCollatorForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class DataCollatorForPermutationLanguageModeling:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
......
......@@ -12,9 +12,7 @@ if is_torch_available():
from transformers import (
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorWithPadding,
default_data_collator,
......@@ -201,13 +199,16 @@ class DataCollatorIntegrationTest(unittest.TestCase):
def test_nsp(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
features = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
data_collator = DataCollatorForLanguageModeling(tokenizer)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
def test_sop(self):
......@@ -216,11 +217,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
{
"input_ids": torch.tensor([0, 1, 2, 3, 4]),
"token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
"sentence_order_label": torch.tensor(i),
"sentence_order_label": i,
}
for i in range(2)
]
data_collator = DataCollatorForSOP(tokenizer)
data_collator = DataCollatorForLanguageModeling(tokenizer)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment