"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "70a127a37a1d168898ec5631872a7aadeec6176a"
Unverified commit 9c4aa4ac authored by Sylvain Gugger, committed by GitHub

Clean up data collators and datasets (#8308)



* Clean up data collators and datasets

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Remove needless clone
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent b1d3e95e
@@ -264,7 +264,15 @@ def main():
         def tokenize_function(examples):
             # Remove empty lines
             examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
-            return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
+            return tokenizer(
+                examples["text"],
+                padding=padding,
+                truncation=True,
+                max_length=data_args.max_seq_length,
+                # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
+                # receives the `special_tokens_mask`.
+                return_special_tokens_mask=True,
+            )
         tokenized_datasets = datasets.map(
             tokenize_function,
@@ -275,8 +283,10 @@ def main():
         )
     else:
         # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+        # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
+        # efficient when it receives the `special_tokens_mask`.
         def tokenize_function(examples):
-            return tokenizer(examples[text_column_name])
+            return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
         tokenized_datasets = datasets.map(
             tokenize_function,
......
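A minimal sketch of the pattern above, assuming a placeholder `bert-base-uncased` checkpoint and an in-memory `"text"` column: tokenizing with `return_special_tokens_mask=True` lets `DataCollatorForLanguageModeling` (see the collator changes below) reuse the precomputed mask instead of calling `get_special_tokens_mask` on every batch.

```python
# Hedged sketch: the checkpoint and sample texts are placeholders, not part of this commit.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    # Ask the tokenizer for the special tokens mask so the collator does not recompute it.
    return tokenizer(examples["text"], truncation=True, max_length=128, return_special_tokens_mask=True)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

features = tokenize_function({"text": ["Hello world.", "A slightly longer example sentence."]})
# Reshape the column-wise batch into per-example dicts, as a Dataset would yield them.
examples = [{k: features[k][i] for k in features} for i in range(2)]
batch = data_collator(examples)
print(batch["input_ids"].shape, batch["labels"].shape)
```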
@@ -281,7 +281,6 @@ if is_torch_available():
     from .data.data_collator import (
         DataCollator,
         DataCollatorForLanguageModeling,
-        DataCollatorForNextSentencePrediction,
         DataCollatorForPermutationLanguageModeling,
         DataCollatorForSOP,
         DataCollatorForTokenClassification,
......
 import random
+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
@@ -175,72 +176,111 @@ class DataCollatorForTokenClassification:
         return batch
+def _collate_batch(examples, tokenizer):
+    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
+    # Tensorize if necessary.
+    if isinstance(examples[0], (list, tuple)):
+        examples = [torch.tensor(e, dtype=torch.long) for e in examples]
+    # Check if padding is necessary.
+    length_of_first = examples[0].size(0)
+    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
+    if are_tensors_same_length:
+        return torch.stack(examples, dim=0)
+    # If yes, check if we have a `pad_token`.
+    if tokenizer._pad_token is None:
+        raise ValueError(
+            "You are attempting to pad samples but the tokenizer you are using"
+            f" ({tokenizer.__class__.__name__}) does not have a pad token."
+        )
+    # Creating the full tensor and filling it with our data.
+    max_length = max(x.size(0) for x in examples)
+    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
+    for i, example in enumerate(examples):
+        if tokenizer.padding_side == "right":
+            result[i, : example.shape[0]] = example
+        else:
+            result[i, -example.shape[0] :] = example
+    return result
 @dataclass
 class DataCollatorForLanguageModeling:
     """
-    Data collator used for language modeling.
-    - collates batches of tensors, honoring their tokenizer's pad_token
-    - preprocesses batches for masked language modeling
+    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
+    are not all of the same length.
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
+            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
+            non-masked tokens and the value to predict for the masked token.
+        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
+            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
+    .. note::
+        For best performance, this data collator should be used with a dataset having items that are dictionaries or
+        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
+        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
+        argument :obj:`return_special_tokens_mask=True`.
     """
     tokenizer: PreTrainedTokenizerBase
     mlm: bool = True
     mlm_probability: float = 0.15
+    def __post_init__(self):
+        if self.mlm and self.tokenizer.mask_token is None:
+            raise ValueError(
+                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
+                "You should pass `mlm=False` to train on causal language modeling instead."
+            )
     def __call__(
         self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
     ) -> Dict[str, torch.Tensor]:
+        # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
-            examples = [e["input_ids"] for e in examples]
-        batch = self._tensorize_batch(examples)
+            batch = self.tokenizer.pad(examples, return_tensors="pt")
+        else:
+            batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
+        # If special token mask has been preprocessed, pop it from the dict.
+        special_tokens_mask = batch.pop("special_tokens_mask", None)
         if self.mlm:
-            inputs, labels = self.mask_tokens(batch)
-            return {"input_ids": inputs, "labels": labels}
+            batch["input_ids"], batch["labels"] = self.mask_tokens(
+                batch["input_ids"], special_tokens_mask=special_tokens_mask
+            )
         else:
-            labels = batch.clone().detach()
+            labels = batch["input_ids"]
             if self.tokenizer.pad_token_id is not None:
                 labels[labels == self.tokenizer.pad_token_id] = -100
-            return {"input_ids": batch, "labels": labels}
+            batch["labels"] = labels
+        return batch
-    def _tensorize_batch(
-        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
-    ) -> torch.Tensor:
-        # In order to accept both lists of lists and lists of Tensors
-        if isinstance(examples[0], (list, tuple)):
-            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def mask_tokens(
+        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
         """
-        if self.tokenizer.mask_token is None:
-            raise ValueError(
-                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
-            )
         labels = inputs.clone()
-        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
+        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
         probability_matrix = torch.full(labels.shape, self.mlm_probability)
-        special_tokens_mask = [
-            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
-        ]
-        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-        if self.tokenizer._pad_token is not None:
-            padding_mask = labels.eq(self.tokenizer.pad_token_id)
-            probability_matrix.masked_fill_(padding_mask, value=0.0)
+        if special_tokens_mask is None:
+            special_tokens_mask = [
+                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+            ]
+            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
+        else:
+            special_tokens_mask = special_tokens_mask.bool()
+        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
         masked_indices = torch.bernoulli(probability_matrix).bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
@@ -385,9 +425,16 @@ class DataCollatorForSOP(DataCollatorForLanguageModeling):
     - preprocesses batches for both masked language modeling and sentence order prediction
     """
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
+            "DataCollatorForLanguageModeling instead.",
+            FutureWarning,
+        )
     def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
         input_ids = [example["input_ids"] for example in examples]
-        input_ids = self._tensorize_batch(input_ids)
+        input_ids = _collate_batch(input_ids, self.tokenizer)
         input_ids, labels, attention_mask = self.mask_tokens(input_ids)
         token_type_ids = [example["token_type_ids"] for example in examples]
@@ -582,136 +629,3 @@ class DataCollatorForPermutationLanguageModeling:
             ) & masked_indices[i]
         return inputs.long(), perm_mask, target_mapping, labels.long()
-@dataclass
-class DataCollatorForNextSentencePrediction:
-    """
-    Data collator used for next sentence prediction. - collates examples which contains pre-generated negative examples
-    - preprocesses batches for masked language modeling
-    """
-    tokenizer: PreTrainedTokenizerBase
-    mlm: bool = True
-    block_size: int = 512
-    short_seq_probability: float = 0.1
-    nsp_probability: float = 0.5
-    mlm_probability: float = 0.15
-    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        """
-        The input should contain negative examples, :class:`~transformers.DataCollatorForNextSentencePrediction` will
-        not generate any negative examples
-        Args:
-            examples (:obj:`List[Dict]`): Each dictionary should have the following keys:
-                - ``tokens_a``: A sequence of tokens, which should appear before ``tokens_b`` in the text.
-                - ``tokens_b``: A sequence of tokens, which should appear after ``tokens_a`` in the text.
-                - ``is_random_next``: 1 if this pair is generated randomly, else 0.
-        """
-        tokens_a = [e["tokens_a"] for e in examples]
-        tokens_b = [e["tokens_b"] for e in examples]
-        nsp_labels = [1 if e["is_random_next"] else 0 for e in examples]
-        input_ids = []
-        segment_ids = []
-        attention_masks = []
-        assert len(tokens_a) == len(tokens_b)
-        for i in range(len(tokens_a)):
-            input_id, attention_mask, segment_id = self.create_features_from_example(tokens_a[i], tokens_b[i])
-            input_ids.append(input_id)
-            segment_ids.append(segment_id)
-            attention_masks.append(attention_mask)
-        if self.mlm:
-            input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids))
-        else:
-            input_ids = self._tensorize_batch(input_ids)
-        result = {
-            "input_ids": input_ids,
-            "attention_mask": self._tensorize_batch(attention_masks),
-            "token_type_ids": self._tensorize_batch(segment_ids),
-            "labels": mlm_labels if self.mlm else None,
-            "next_sentence_label": torch.tensor(nsp_labels),
-        }
-        return result
-    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
-    def create_features_from_example(self, tokens_a, tokens_b):
-        """Creates examples for a single document."""
-        max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
-        tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences(
-            tokens_a,
-            tokens_b,
-            num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens,
-            truncation_strategy="longest_first",
-        )
-        input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
-        attention_mask = [1] * len(input_id)
-        segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
-        assert len(input_id) <= self.block_size
-        # pad
-        while len(input_id) < self.block_size:
-            input_id.append(0)
-            attention_mask.append(0)
-            segment_id.append(0)
-        input_id = torch.tensor(input_id)
-        attention_mask = torch.tensor(attention_mask)
-        segment_id = torch.tensor(segment_id)
-        return input_id, attention_mask, segment_id
-    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
-        """
-        if self.tokenizer.mask_token is None:
-            raise ValueError(
-                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
-            )
-        labels = inputs.clone()
-        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
-        probability_matrix = torch.full(labels.shape, self.mlm_probability)
-        special_tokens_mask = [
-            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
-        ]
-        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-        if self.tokenizer._pad_token is not None:
-            padding_mask = labels.eq(self.tokenizer.pad_token_id)
-            probability_matrix.masked_fill_(padding_mask, value=0.0)
-        masked_indices = torch.bernoulli(probability_matrix).bool()
-        labels[~masked_indices] = -100  # We only compute loss on masked tokens
-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
-        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
-        inputs[indices_random] = random_words[indices_random]
-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
-        return inputs, labels
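The refactored `DataCollatorForLanguageModeling` above now has two entry paths: dict-like examples are padded through `tokenizer.pad`, while bare lists or tensors of token ids go through the new `_collate_batch` helper; with `mlm=False` the labels are the inputs with padding positions set to -100. A hedged sketch of both paths, using made-up token ids and a placeholder `bert-base-uncased` tokenizer:

```python
# Illustrative only: the token ids and checkpoint below are made up for the sketch.
import torch
from transformers import BertTokenizer, DataCollatorForLanguageModeling

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# 1) Bare lists of token ids are padded by the `_collate_batch` helper.
batch = collator([[101, 7592, 102], [101, 7592, 2088, 999, 102]])
assert batch["input_ids"].shape == torch.Size((2, 5))
# With mlm=False, padding positions in the labels are set to -100.
assert bool((batch["labels"][0, 3:] == -100).all())

# 2) Dicts with a `special_tokens_mask` key are padded via `tokenizer.pad`,
#    and the mask is popped before masking would happen.
features = tokenizer(["hello", "hello world !"], return_special_tokens_mask=True)
examples = [{k: features[k][i] for k in features} for i in range(2)]
batch = collator(examples)
assert "labels" in batch and "special_tokens_mask" not in batch
```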
@@ -3,6 +3,7 @@ import os
 import pickle
 import random
 import time
+import warnings
 from typing import Dict, List, Optional
 import torch
@@ -17,6 +18,11 @@ from ...utils import logging
 logger = logging.get_logger(__name__)
+DEPRECATION_WARNING = (
+    "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library."
+)
 class TextDataset(Dataset):
     """
     This will be superseded by a framework-agnostic approach soon.
@@ -30,6 +36,7 @@ class TextDataset(Dataset):
         overwrite_cache=False,
         cache_dir: Optional[str] = None,
     ):
+        warnings.warn(DEPRECATION_WARNING, FutureWarning)
         assert os.path.isfile(file_path), f"Input file path {file_path} not found"
         block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
@@ -94,6 +101,7 @@ class LineByLineTextDataset(Dataset):
     """
     def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
+        warnings.warn(DEPRECATION_WARNING, FutureWarning)
         assert os.path.isfile(file_path), f"Input file path {file_path} not found"
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
@@ -120,6 +128,7 @@ class LineByLineWithRefDataset(Dataset):
     """
     def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
+        warnings.warn(DEPRECATION_WARNING, FutureWarning)
         assert os.path.isfile(file_path), f"Input file path {file_path} not found"
         assert os.path.isfile(ref_path), f"Ref file path {file_path} not found"
         # Here, we do not cache the features, operating under the assumption
@@ -156,6 +165,7 @@ class LineByLineWithSOPTextDataset(Dataset):
     """
    def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
+        warnings.warn(DEPRECATION_WARNING, FutureWarning)
         assert os.path.isdir(file_dir)
         logger.info(f"Creating features from dataset file folder at {file_dir}")
         self.examples = []
@@ -305,6 +315,7 @@ class TextDatasetForNextSentencePrediction(Dataset):
         short_seq_probability=0.1,
         nsp_probability=0.5,
     ):
+        warnings.warn(DEPRECATION_WARNING, FutureWarning)
         assert os.path.isfile(file_path), f"Input file path {file_path} not found"
         self.block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)
@@ -449,9 +460,18 @@ class TextDatasetForNextSentencePrediction(Dataset):
                 assert len(tokens_a) >= 1
                 assert len(tokens_b) >= 1
-                self.examples.append(
-                    {"tokens_a": tokens_a, "tokens_b": tokens_b, "is_random_next": is_random_next}
-                )
+                # add special tokens
+                input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
+                # add token type ids, 0 for sentence a, 1 for sentence b
+                token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
+                example = {
+                    "input_ids": torch.tensor(input_ids, dtype=torch.long),
+                    "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+                    "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
+                }
+                self.examples.append(example)
             current_chunk = []
             current_length = 0
......
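With `TextDatasetForNextSentencePrediction` now emitting ready-made `input_ids`, `token_type_ids` and `next_sentence_label` tensors, the dedicated NSP collator is no longer needed: the generic `DataCollatorForLanguageModeling` pads these dicts through `tokenizer.pad`, adds the MLM labels, and carries the extra `next_sentence_label` key through. A rough sketch, assuming a placeholder corpus file and checkpoint:

```python
# Hypothetical usage after this change; "corpus.txt" and the checkpoint are placeholders.
from torch.utils.data import DataLoader
from transformers import BertTokenizer, DataCollatorForLanguageModeling, TextDatasetForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=tokenizer,
    file_path="corpus.txt",  # one sentence per line, blank lines between documents
    block_size=128,
)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
loader = DataLoader(dataset, batch_size=8, collate_fn=collator)

batch = next(iter(loader))
# input_ids/labels for MLM, token_type_ids for the two segments, next_sentence_label for NSP.
print({k: v.shape for k, v in batch.items()})
```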
@@ -26,11 +26,6 @@ class DataCollatorForLanguageModeling:
         requires_pytorch(self)
-class DataCollatorForNextSentencePrediction:
-    def __init__(self, *args, **kwargs):
-        requires_pytorch(self)
 class DataCollatorForPermutationLanguageModeling:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
......
@@ -12,9 +12,7 @@ if is_torch_available():
     from transformers import (
         DataCollatorForLanguageModeling,
-        DataCollatorForNextSentencePrediction,
         DataCollatorForPermutationLanguageModeling,
-        DataCollatorForSOP,
         DataCollatorForTokenClassification,
         DataCollatorWithPadding,
         default_data_collator,
@@ -201,13 +199,16 @@ class DataCollatorIntegrationTest(unittest.TestCase):
     def test_nsp(self):
         tokenizer = BertTokenizer(self.vocab_file)
-        features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
-        data_collator = DataCollatorForNextSentencePrediction(tokenizer)
+        features = [
+            {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
+            for i in range(2)
+        ]
+        data_collator = DataCollatorForLanguageModeling(tokenizer)
         batch = data_collator(features)
-        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
-        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
-        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
     def test_sop(self):
@@ -216,11 +217,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
             {
                 "input_ids": torch.tensor([0, 1, 2, 3, 4]),
                 "token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
-                "sentence_order_label": torch.tensor(i),
+                "sentence_order_label": i,
             }
             for i in range(2)
         ]
-        data_collator = DataCollatorForSOP(tokenizer)
+        data_collator = DataCollatorForLanguageModeling(tokenizer)
         batch = data_collator(features)
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
......
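The updated tests encode the migration: `DataCollatorForSOP` still constructs but now emits a `FutureWarning`, and SOP/NSP-style features go straight to `DataCollatorForLanguageModeling`, which lets extra label keys such as `sentence_order_label` ride through `tokenizer.pad`. A small sketch under those assumptions, with a placeholder checkpoint and made-up features:

```python
# Sketch of the deprecation path and the replacement the tests above exercise.
import warnings
import torch
from transformers import BertTokenizer, DataCollatorForLanguageModeling, DataCollatorForSOP

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# The deprecated collator still instantiates, but warns on construction.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DataCollatorForSOP(tokenizer=tokenizer)
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Preferred path: the generic collator with SOP-style features.
features = [
    {
        "input_ids": torch.tensor([0, 1, 2, 3, 4]),
        "token_type_ids": torch.tensor([0, 0, 1, 1, 1]),
        "sentence_order_label": i,
    }
    for i in range(2)
]
batch = DataCollatorForLanguageModeling(tokenizer)(features)
assert batch["sentence_order_label"].shape == torch.Size((2,))
```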