Unverified commit 6c40e497, authored by Andrea Cappelli and committed by GitHub

Run mlm pad to multiple for fp16 (#11128)

* Add mlm collator pad to multiple option (#10627)

* Use padding to 8x in run mlm (#10627)
parent dfed4ec2
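
For context (not part of the diff): padding every batch to a multiple of 8 gives fp16 matmuls Tensor Core-friendly shapes, which is why the script only enables the option when fp16 training is on and padding is dynamic. Below is a minimal usage sketch of the new collator option; the `bert-base-uncased` checkpoint and the example sentences are arbitrary illustrations, not taken from this change.

from transformers import BertTokenizerFast, DataCollatorForLanguageModeling

# Any checkpoint with a pad and mask token works; this one is just an example.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Pad dynamically, but round every batch length up to a multiple of 8
# so fp16 kernels can use Tensor Cores efficiently.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
    pad_to_multiple_of=8,
)

features = [
    {"input_ids": tokenizer("A short line.")["input_ids"]},
    {"input_ids": tokenizer("A somewhat longer line that needs more padding.")["input_ids"]},
]
batch = data_collator(features)
# Both sequences end up padded to the same length, rounded up to a multiple of 8.
assert batch["input_ids"].shape[1] % 8 == 0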
@@ -422,7 +422,12 @@ def main():
     # Data collator
     # This one will take care of randomly masking the tokens.
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm_probability=data_args.mlm_probability,
+        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
+    )
 
     # Initialize our Trainer
     trainer = Trainer(
...
@@ -192,7 +192,7 @@ class DataCollatorForTokenClassification:
         return batch
 
-def _collate_batch(examples, tokenizer):
+def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
     """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
     # Tensorize if necessary.
     if isinstance(examples[0], (list, tuple)):
@@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer):
     # Check if padding is necessary.
     length_of_first = examples[0].size(0)
     are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-    if are_tensors_same_length:
+    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
         return torch.stack(examples, dim=0)
 
     # If yes, check if we have a `pad_token`.
@@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer):
     # Creating the full tensor and filling it with our data.
     max_length = max(x.size(0) for x in examples)
+    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
     result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
     for i, example in enumerate(examples):
         if tokenizer.padding_side == "right":
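
A quick worked example of the round-up branch added above (illustrative only, not part of the diff): with a longest sequence of 10 tokens and pad_to_multiple_of=8, the batch is padded out to 16.

# Round max_length up to the next multiple of pad_to_multiple_of,
# mirroring the branch added to _collate_batch above.
max_length = 10
pad_to_multiple_of = 8
if max_length % pad_to_multiple_of != 0:
    max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
print(max_length)  # 16, i.e. (10 // 8 + 1) * 8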
@@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling:
             non-masked tokens and the value to predict for the masked token.
         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
             The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set will pad the sequence to a multiple of the provided value.
 
     .. note::
@@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling:
     tokenizer: PreTrainedTokenizerBase
     mlm: bool = True
     mlm_probability: float = 0.15
+    pad_to_multiple_of: Optional[int] = None
 
     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling:
     ) -> Dict[str, torch.Tensor]:
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
-            batch = self.tokenizer.pad(examples, return_tensors="pt")
+            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
-            batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
+            batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}
 
         # If special token mask has been preprocessed, pop it from the dict.
         special_tokens_mask = batch.pop("special_tokens_mask", None)
...
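
For reference, a standalone sketch (not part of the diff; the checkpoint name is only an example) of the end-to-end behavior the tests below assert: sequences of length 5 and 10 collate into a batch of width 16 when pad_to_multiple_of=8.

from transformers import BertTokenizerFast, DataCollatorForLanguageModeling

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # any checkpoint with a pad token
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8)

# Sequences of length 5 and 10; the batch width is rounded up from 10 to 16.
features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 16])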
@@ -146,11 +146,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
         self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
 
-    def test_data_collator_for_language_modeling(self):
+    def _test_no_pad_and_pad(self, no_pad_features, pad_features):
         tokenizer = BertTokenizer(self.vocab_file)
-        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
-        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
         data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
         batch = data_collator(no_pad_features)
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
@@ -160,6 +157,15 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
 
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8)
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
         tokenizer._pad_token = None
         data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
         with self.assertRaises(ValueError):
@@ -185,6 +191,32 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.any(masked_tokens))
         self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
 
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(torch.any(masked_tokens))
+        self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(torch.any(masked_tokens))
+        self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+    def test_data_collator_for_language_modeling(self):
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
+        no_pad_features = [list(range(10)), list(range(10))]
+        pad_features = [list(range(5)), list(range(10))]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -225,6 +257,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
 
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
+
     def test_sop(self):
         tokenizer = BertTokenizer(self.vocab_file)
         features = [
@@ -242,3 +282,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))