Unverified Commit 71c19850 authored by Anton Vlasjuk, committed by GitHub

Immutability for data collators (#30603)

* immutability fix for seq2seq as well as immutability tests for the collators

* ensure we don't act on none labels and formatting

* remove tf/pt in respective tests as they are not required

* more type error fixes tf/np

* remove todo

* apply suggestions from code review

* formatting / style
parent 5962d62b
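
In a nutshell: `DataCollatorForSeq2Seq` previously padded `feature["labels"]` in place, mutating the caller's feature dicts; with this change the labels are padded into the returned batch and the inputs are left untouched. A minimal sketch of that contract (not part of the commit; assumes any tokenizer with a pad token, here `bert-base-uncased`):

```python
import copy

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
features = [
    {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
    {"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
]
backup = copy.deepcopy(features)

collator = DataCollatorForSeq2Seq(tokenizer, return_tensors="np")
batch = collator(features)

assert batch["labels"].shape == (2, 6)  # labels padded to the longest sequence
assert features == backup  # the caller's feature dicts are unchanged
```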
@@ -585,51 +585,84 @@ class DataCollatorForSeq2Seq:
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        # reconvert list[None] to None if necessary
        # this might occur when we pass {..., "labels": None}
        if labels is not None and all(label is None for label in labels):
            labels = None
        non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        # run through tokenizer without labels to ensure no side effects
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            non_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
        no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
        if labels is not None:
            if no_padding:
                if isinstance(features[0][label_name], list):
                    batch["labels"] = list(labels)
                else:
                    batch["labels"] = [np.concatenate([label, []]) for label in labels]
            else:
                max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
                max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
                if self.pad_to_multiple_of is not None:
                    max_label_length = (
                        (max_label_length + self.pad_to_multiple_of - 1)
                        // self.pad_to_multiple_of
                        * self.pad_to_multiple_of
                    )

                padding_side = self.tokenizer.padding_side
                if isinstance(features[0][label_name], list):
                    batch["labels"] = [
                        label + [self.label_pad_token_id] * (max_label_length - len(label))
                        if padding_side == "right"
                        else [self.label_pad_token_id] * (max_label_length - len(label)) + label
                        for label in labels
                    ]
                else:
                    batch["labels"] = [
                        np.concatenate([label, [self.label_pad_token_id] * (max_label_length - len(label))])
                        if padding_side == "right"
                        else np.concatenate([[self.label_pad_token_id] * (max_label_length - len(label)), label])
                        for label in labels
                    ]

        # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
        if batch.get("labels", None) is not None:
            if return_tensors == "pt":
                import torch

                batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
            elif return_tensors == "tf":
                import tensorflow as tf

                batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
            else:
                batch["labels"] = np.array(batch["labels"], dtype=np.int64)
        else:
            batch["labels"] = None

        # prepare decoder_input_ids
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
            batch["decoder_input_ids"] = decoder_input_ids

        return batch
@dataclass
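Two consequences of the new code paths, sketched with the same `tokenizer` and `collator` as in the snippet near the top (illustrative, not part of the commit):

```python
import numpy as np

from transformers import DataCollatorForSeq2Seq

# 1) {..., "labels": None} no longer breaks collation: list[None] collapses back to None
batch = collator([{"input_ids": [0, 1, 2], "labels": None}])
assert batch["labels"] is None

# 2) with padding disabled, equal-length labels pass through and still honor `return_tensors`
no_pad_collator = DataCollatorForSeq2Seq(tokenizer, padding="do_not_pad", return_tensors="np")
batch = no_pad_collator(
    [
        {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
        {"input_ids": [3, 4, 5], "labels": [3, 4, 5]},
    ]
)
assert batch["labels"].dtype == np.int64

# 3) the `pad_to_multiple_of` branch is plain ceiling arithmetic
def round_up(length, multiple):
    return (length + multiple - 1) // multiple * multiple

assert [round_up(n, 8) for n in (5, 8, 9)] == [8, 8, 16]
```

Note also that `decoder_input_ids` are now derived from the collated `batch["labels"]` rather than from the (previously mutated) input features, so they are built from whatever tensor type `return_tensors` requests.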
@@ -439,6 +439,330 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
@require_torch
class DataCollatorImmutabilityTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def _turn_to_none(self, item):
"""used to convert `item` to `None` type"""
return None
def _validate_original_data_against_collated_data(self, collator, original_data, batch_data):
# we only care about side effects, the results are tested elsewhere
collator(batch_data)
        # we go through every item and convert to `primitive` datatypes if necessary
        # then compare the original data against the data that has been passed through the collator for equivalence
for original, batch in zip(original_data, batch_data):
for original_val, batch_val in zip(original.values(), batch.values()):
if isinstance(original_val, (np.ndarray, torch.Tensor)):
self.assertEqual(original_val.tolist(), batch_val.tolist())
else:
self.assertEqual(original_val, batch_val)
def _validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
self, collator, base_data, input_key, input_datatype, label_key, label_datatype, ignore_label=False
):
# using the arguments to recreate the features with their respective (potentially new) datatypes
features_original = [
{label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])}
for sample in base_data
]
features_batch = [
{label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])}
for sample in base_data
]
        # some collators do not use labels, and sometimes we want to check that a label-aware collator handles their absence
if ignore_label:
for original, batch in zip(features_original, features_batch):
original.pop(label_key)
batch.pop(label_key)
self._validate_original_data_against_collated_data(
collator=collator, original_data=features_original, batch_data=features_batch
)
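
As an aside (not part of the test file): the datatype matrices used below work by rebuilding each base sample through the given constructors, so one row, say `(np.array, torch.tensor)`, expands to roughly the following, while a `_turn_to_none` row produces `{"label": None, ...}` to exercise the None-label path:

```python
import numpy as np
import torch

base = [{"label": 0, "inputs": (0, 1, 2, 3, 4, 5)}]
# input_datatype=np.array, label_datatype=torch.tensor
rebuilt = [{"label": torch.tensor(s["label"]), "inputs": np.array(s["inputs"])} for s in base]
```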
def test_default_collator_immutability(self):
features_base_single_label = [{"label": i, "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)]
features_base_multiple_labels = [{"label": (0, 1, 2), "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)]
for datatype_input, datatype_label in [
(list, int),
(list, float),
(np.array, int),
(np.array, torch.tensor),
(list, self._turn_to_none),
]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=default_data_collator,
base_data=features_base_single_label,
input_key="inputs",
input_datatype=datatype_input,
label_key="label",
label_datatype=datatype_label,
)
for datatype_input, datatype_label in [(list, list), (list, self._turn_to_none)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=default_data_collator,
base_data=features_base_multiple_labels,
input_key="inputs",
input_datatype=datatype_input,
label_key="label",
label_datatype=datatype_label,
)
features_base_single_label_alt = [{"input_ids": (0, 1, 2, 3, 4), "label": float(i)} for i in range(4)]
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=default_data_collator,
base_data=features_base_single_label_alt,
input_key="input_ids",
input_datatype=list,
label_key="label",
label_datatype=float,
)
def test_with_padding_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
features_batch = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10)
self._validate_original_data_against_collated_data(
collator=data_collator, original_data=features_original, batch_data=features_batch
)
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
self._validate_original_data_against_collated_data(
collator=data_collator, original_data=features_original, batch_data=features_batch
)
def test_for_token_classification_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": (0, 1, 2), "labels": (0, 1, 2)},
{"input_ids": (0, 1, 2, 3, 4, 5), "labels": (0, 1, 2, 3, 4, 5)},
]
token_classification_collators = [
DataCollatorForTokenClassification(tokenizer),
DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10),
DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8),
DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1),
]
for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]:
for collator in token_classification_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=token_classification_collators[-1],
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
def test_seq2seq_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": list(range(3)), "labels": list(range(3))},
{"input_ids": list(range(6)), "labels": list(range(6))},
]
seq2seq_collators = [
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST),
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=7),
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8),
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1),
]
for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]:
for collator in seq2seq_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=seq2seq_collators[-1],
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
features_base_no_pad = [
{"input_ids": list(range(3)), "labels": list(range(3))},
{"input_ids": list(range(3)), "labels": list(range(3))},
]
seq2seq_no_padding_collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.DO_NOT_PAD)
for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=seq2seq_no_padding_collator,
base_data=features_base_no_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
def test_language_modelling_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base_no_pad = [
{"input_ids": tuple(range(10)), "labels": (1,)},
{"input_ids": tuple(range(10)), "labels": (1,)},
]
features_base_pad = [
{"input_ids": tuple(range(5)), "labels": (1,)},
{"input_ids": tuple(range(5)), "labels": (1,)},
]
lm_collators = [
DataCollatorForLanguageModeling(tokenizer, mlm=False),
DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8),
DataCollatorForLanguageModeling(tokenizer),
DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8),
]
for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]:
for collator in lm_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base_no_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
    def test_whole_word_masking_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": list(range(10)), "labels": (1,)},
{"input_ids": list(range(10)), "labels": (1,)},
]
whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
for datatype_input, datatype_label in [(list, list), (np.array, np.array)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=whole_word_masking_collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
def test_permutation_language_modelling_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
plm_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
no_pad_features_original = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
no_pad_features_batch = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
self._validate_original_data_against_collated_data(
collator=plm_collator, original_data=no_pad_features_original, batch_data=no_pad_features_batch
)
pad_features_original = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
pad_features_batch = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
self._validate_original_data_against_collated_data(
collator=plm_collator, original_data=pad_features_original, batch_data=pad_features_batch
)
def test_next_sentence_prediction_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
features_batch = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
nsp_collator = DataCollatorForLanguageModeling(tokenizer)
self._validate_original_data_against_collated_data(
collator=nsp_collator, original_data=features_original, batch_data=features_batch
)
nsp_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
self._validate_original_data_against_collated_data(
collator=nsp_collator, original_data=features_original, batch_data=features_batch
)
def test_sentence_order_prediction_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [
{
"input_ids": torch.tensor([0, 1, 2, 3, 4]),
"token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
"sentence_order_label": i,
}
for i in range(2)
]
features_batch = [
{
"input_ids": torch.tensor([0, 1, 2, 3, 4]),
"token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
"sentence_order_label": i,
}
for i in range(2)
]
sop_collator = DataCollatorForLanguageModeling(tokenizer)
self._validate_original_data_against_collated_data(
collator=sop_collator, original_data=features_original, batch_data=features_batch
)
sop_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
self._validate_original_data_against_collated_data(
collator=sop_collator, original_data=features_original, batch_data=features_batch
)
@require_tf
class TFDataCollatorIntegrationTest(unittest.TestCase):
def setUp(self):
@@ -794,6 +1118,338 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["sentence_order_label"].shape.as_list(), [2])
@require_tf
class TFDataCollatorImmutabilityTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def _turn_to_none(self, item):
"""used to convert `item` to `None` type"""
return None
def _validate_original_data_against_collated_data(self, collator, original_data, batch_data):
# we only care about side effects, the results are tested elsewhere
collator(batch_data)
        # we go through every item and convert to `primitive` datatypes if necessary
        # then compare the original data against the data that has been passed through the collator for equivalence
for original, batch in zip(original_data, batch_data):
for original_val, batch_val in zip(original.values(), batch.values()):
if isinstance(original_val, np.ndarray):
self.assertEqual(original_val.tolist(), batch_val.tolist())
elif isinstance(original_val, tf.Tensor):
self.assertEqual(original_val.numpy().tolist(), batch_val.numpy().tolist())
else:
self.assertEqual(original_val, batch_val)
def _validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
self, collator, base_data, input_key, input_datatype, label_key, label_datatype, ignore_label=False
):
# using the arguments to recreate the features with their respective (potentially new) datatypes
features_original = [
{label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])}
for sample in base_data
]
features_batch = [
{label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])}
for sample in base_data
]
        # some collators do not use labels, and sometimes we want to check that a label-aware collator handles their absence
if ignore_label:
for original, batch in zip(features_original, features_batch):
original.pop(label_key)
batch.pop(label_key)
self._validate_original_data_against_collated_data(
collator=collator, original_data=features_original, batch_data=features_batch
)
def test_default_collator_immutability(self):
features_base_single_label = [{"label": i, "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)]
features_base_multiple_labels = [{"label": (0, 1, 2), "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)]
for datatype_input, datatype_label in [
(list, int),
(list, float),
(np.array, int),
(np.array, tf.constant),
(list, self._turn_to_none),
]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=lambda x: default_data_collator(x, return_tensors="tf"),
base_data=features_base_single_label,
input_key="inputs",
input_datatype=datatype_input,
label_key="label",
label_datatype=datatype_label,
)
for datatype_input, datatype_label in [(list, list), (list, self._turn_to_none)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=lambda x: default_data_collator(x, return_tensors="tf"),
base_data=features_base_multiple_labels,
input_key="inputs",
input_datatype=datatype_input,
label_key="label",
label_datatype=datatype_label,
)
features_base_single_label_alt = [{"input_ids": (0, 1, 2, 3, 4), "label": float(i)} for i in range(4)]
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=lambda x: default_data_collator(x, return_tensors="tf"),
base_data=features_base_single_label_alt,
input_key="input_ids",
input_datatype=list,
label_key="label",
label_datatype=float,
)
def test_with_padding_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
features_batch = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10, return_tensors="tf")
self._validate_original_data_against_collated_data(
collator=data_collator, original_data=features_original, batch_data=features_batch
)
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
self._validate_original_data_against_collated_data(
collator=data_collator, original_data=features_original, batch_data=features_batch
)
def test_for_token_classification_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": (0, 1, 2), "labels": (0, 1, 2)},
{"input_ids": (0, 1, 2, 3, 4, 5), "labels": (0, 1, 2, 3, 4, 5)},
]
token_classification_collators = [
DataCollatorForTokenClassification(tokenizer, return_tensors="tf"),
DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10, return_tensors="tf"),
DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8, return_tensors="tf"),
DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1, return_tensors="tf"),
]
for datatype_input, datatype_label in [(list, list)]:
for collator in token_classification_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=token_classification_collators[-1],
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
def test_seq2seq_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": list(range(3)), "labels": list(range(3))},
{"input_ids": list(range(6)), "labels": list(range(6))},
]
seq2seq_collators = [
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, return_tensors="tf"),
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=7, return_tensors="tf"),
DataCollatorForSeq2Seq(
tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8, return_tensors="tf"
),
DataCollatorForSeq2Seq(
tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1, return_tensors="tf"
),
]
for datatype_input, datatype_label in [(list, list)]:
for collator in seq2seq_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=seq2seq_collators[-1],
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
features_base_no_pad = [
{"input_ids": list(range(3)), "labels": list(range(3))},
{"input_ids": list(range(3)), "labels": list(range(3))},
]
seq2seq_no_padding_collator = DataCollatorForSeq2Seq(
tokenizer, padding=PaddingStrategy.DO_NOT_PAD, return_tensors="tf"
)
for datatype_input, datatype_label in [(list, list)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=seq2seq_no_padding_collator,
base_data=features_base_no_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
def test_language_modelling_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base_no_pad = [
{"input_ids": tuple(range(10)), "labels": (1,)},
{"input_ids": tuple(range(10)), "labels": (1,)},
]
features_base_pad = [
{"input_ids": tuple(range(5)), "labels": (1,)},
{"input_ids": tuple(range(5)), "labels": (1,)},
]
lm_collators = [
DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf"),
DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="tf"),
DataCollatorForLanguageModeling(tokenizer, return_tensors="tf"),
DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf"),
]
for datatype_input, datatype_label in [(list, list)]:
for collator in lm_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base_no_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
    def test_whole_word_masking_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": list(range(10)), "labels": (1,)},
{"input_ids": list(range(10)), "labels": (1,)},
]
whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
for datatype_input, datatype_label in [(list, list), (np.array, np.array)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=whole_word_masking_collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
def test_permutation_language_modelling_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
plm_collator = DataCollatorForPermutationLanguageModeling(tokenizer, return_tensors="tf")
no_pad_features_original = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
no_pad_features_batch = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
self._validate_original_data_against_collated_data(
collator=plm_collator, original_data=no_pad_features_original, batch_data=no_pad_features_batch
)
pad_features_original = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
pad_features_batch = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
self._validate_original_data_against_collated_data(
collator=plm_collator, original_data=pad_features_original, batch_data=pad_features_batch
)
def test_next_sentence_prediction_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
features_batch = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
nsp_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf")
self._validate_original_data_against_collated_data(
collator=nsp_collator, original_data=features_original, batch_data=features_batch
)
nsp_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
self._validate_original_data_against_collated_data(
collator=nsp_collator, original_data=features_original, batch_data=features_batch
)
def test_sentence_order_prediction_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [
{
"input_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]),
"token_type_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]),
"sentence_order_label": i,
}
for i in range(2)
]
features_batch = [
{
"input_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]),
"token_type_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]),
"sentence_order_label": i,
}
for i in range(2)
]
sop_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf")
self._validate_original_data_against_collated_data(
collator=sop_collator, original_data=features_original, batch_data=features_batch
)
sop_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf")
self._validate_original_data_against_collated_data(
collator=sop_collator, original_data=features_original, batch_data=features_batch
)
class NumpyDataCollatorIntegrationTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
@@ -1137,3 +1793,332 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["token_type_ids"].shape, (2, 8))
self.assertEqual(batch["labels"].shape, (2, 8))
self.assertEqual(batch["sentence_order_label"].shape, (2,))
class NumpyDataCollatorImmutabilityTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def _turn_to_none(self, item):
"""used to convert `item` to `None` type"""
return None
def _validate_original_data_against_collated_data(self, collator, original_data, batch_data):
# we only care about side effects, the results are tested elsewhere
collator(batch_data)
        # we go through every item and convert to `primitive` datatypes if necessary
        # then compare the original data against the data that has been passed through the collator for equivalence
for original, batch in zip(original_data, batch_data):
for original_val, batch_val in zip(original.values(), batch.values()):
if isinstance(original_val, np.ndarray):
self.assertEqual(original_val.tolist(), batch_val.tolist())
else:
self.assertEqual(original_val, batch_val)
def _validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
self, collator, base_data, input_key, input_datatype, label_key, label_datatype, ignore_label=False
):
# using the arguments to recreate the features with their respective (potentially new) datatypes
features_original = [
{label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])}
for sample in base_data
]
features_batch = [
{label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])}
for sample in base_data
]
        # some collators do not use labels, and sometimes we want to check that a label-aware collator handles their absence
if ignore_label:
for original, batch in zip(features_original, features_batch):
original.pop(label_key)
batch.pop(label_key)
self._validate_original_data_against_collated_data(
collator=collator, original_data=features_original, batch_data=features_batch
)
def test_default_collator_immutability(self):
features_base_single_label = [{"label": i, "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)]
features_base_multiple_labels = [{"label": (0, 1, 2), "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)]
for datatype_input, datatype_label in [
(list, int),
(list, float),
(np.array, int),
(np.array, np.array),
(list, self._turn_to_none),
]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=lambda x: default_data_collator(x, return_tensors="np"),
base_data=features_base_single_label,
input_key="inputs",
input_datatype=datatype_input,
label_key="label",
label_datatype=datatype_label,
)
for datatype_input, datatype_label in [(list, list), (list, self._turn_to_none)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=lambda x: default_data_collator(x, return_tensors="np"),
base_data=features_base_multiple_labels,
input_key="inputs",
input_datatype=datatype_input,
label_key="label",
label_datatype=datatype_label,
)
features_base_single_label_alt = [{"input_ids": (0, 1, 2, 3, 4), "label": float(i)} for i in range(4)]
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=lambda x: default_data_collator(x, return_tensors="np"),
base_data=features_base_single_label_alt,
input_key="input_ids",
input_datatype=list,
label_key="label",
label_datatype=float,
)
def test_with_padding_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
features_batch = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10, return_tensors="np")
self._validate_original_data_against_collated_data(
collator=data_collator, original_data=features_original, batch_data=features_batch
)
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="np")
self._validate_original_data_against_collated_data(
collator=data_collator, original_data=features_original, batch_data=features_batch
)
def test_for_token_classification_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": (0, 1, 2), "labels": (0, 1, 2)},
{"input_ids": (0, 1, 2, 3, 4, 5), "labels": (0, 1, 2, 3, 4, 5)},
]
token_classification_collators = [
DataCollatorForTokenClassification(tokenizer, return_tensors="np"),
DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10, return_tensors="np"),
DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8, return_tensors="np"),
DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1, return_tensors="np"),
]
for datatype_input, datatype_label in [(list, list)]:
for collator in token_classification_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=token_classification_collators[-1],
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
def test_seq2seq_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": list(range(3)), "labels": list(range(3))},
{"input_ids": list(range(6)), "labels": list(range(6))},
]
seq2seq_collators = [
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, return_tensors="np"),
DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=7, return_tensors="np"),
DataCollatorForSeq2Seq(
tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8, return_tensors="np"
),
DataCollatorForSeq2Seq(
tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1, return_tensors="np"
),
]
for datatype_input, datatype_label in [(list, list)]:
for collator in seq2seq_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=seq2seq_collators[-1],
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
features_base_no_pad = [
{"input_ids": list(range(3)), "labels": list(range(3))},
{"input_ids": list(range(3)), "labels": list(range(3))},
]
seq2seq_no_padding_collator = DataCollatorForSeq2Seq(
tokenizer, padding=PaddingStrategy.DO_NOT_PAD, return_tensors="np"
)
for datatype_input, datatype_label in [(list, list)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=seq2seq_no_padding_collator,
base_data=features_base_no_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
)
def test_language_modelling_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base_no_pad = [
{"input_ids": tuple(range(10)), "labels": (1,)},
{"input_ids": tuple(range(10)), "labels": (1,)},
]
features_base_pad = [
{"input_ids": tuple(range(5)), "labels": (1,)},
{"input_ids": tuple(range(5)), "labels": (1,)},
]
lm_collators = [
DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="np"),
DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="np"),
DataCollatorForLanguageModeling(tokenizer, return_tensors="np"),
DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np"),
]
for datatype_input, datatype_label in [(list, list)]:
for collator in lm_collators:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base_no_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=collator,
base_data=features_base_pad,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
    def test_whole_word_masking_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_base = [
{"input_ids": list(range(10)), "labels": (1,)},
{"input_ids": list(range(10)), "labels": (1,)},
]
whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
for datatype_input, datatype_label in [(list, list), (np.array, np.array)]:
self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes(
collator=whole_word_masking_collator,
base_data=features_base,
input_key="input_ids",
input_datatype=datatype_input,
label_key="labels",
label_datatype=datatype_label,
ignore_label=True,
)
def test_permutation_language_modelling_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
plm_collator = DataCollatorForPermutationLanguageModeling(tokenizer, return_tensors="np")
no_pad_features_original = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
no_pad_features_batch = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
self._validate_original_data_against_collated_data(
collator=plm_collator, original_data=no_pad_features_original, batch_data=no_pad_features_batch
)
pad_features_original = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
pad_features_batch = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
self._validate_original_data_against_collated_data(
collator=plm_collator, original_data=pad_features_original, batch_data=pad_features_batch
)
def test_next_sentence_prediction_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
features_batch = [
{"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i}
for i in range(2)
]
nsp_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np")
self._validate_original_data_against_collated_data(
collator=nsp_collator, original_data=features_original, batch_data=features_batch
)
nsp_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np")
self._validate_original_data_against_collated_data(
collator=nsp_collator, original_data=features_original, batch_data=features_batch
)
def test_sentence_order_prediction_collator_immutability(self):
tokenizer = BertTokenizer(self.vocab_file)
features_original = [
{
"input_ids": np.array([0, 1, 2, 3, 4]),
"token_type_ids": np.array([0, 1, 2, 3, 4]),
"sentence_order_label": i,
}
for i in range(2)
]
features_batch = [
{
"input_ids": np.array([0, 1, 2, 3, 4]),
"token_type_ids": np.array([0, 1, 2, 3, 4]),
"sentence_order_label": i,
}
for i in range(2)
]
sop_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np")
self._validate_original_data_against_collated_data(
collator=sop_collator, original_data=features_original, batch_data=features_batch
)
sop_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np")
self._validate_original_data_against_collated_data(
collator=sop_collator, original_data=features_original, batch_data=features_batch
)