Unverified Commit 3dcb748e authored by Shashank Gupta, committed by GitHub

Added data collator for permutation (XLNet) language modeling and related calls (#5522)

* Added data collator for XLNet language modeling and related calls

Added DataCollatorForXLNetLanguageModeling in data/data_collator.py
to generate necessary inputs for language modeling training with
XLNetLMHeadModel. Also added related arguments, logic and calls in
examples/language-modeling/run_language_modeling.py.

Resolves: #4739, #2008 (partially)

* Changed name to `DataCollatorForPermutationLanguageModeling`

Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModeling`.
Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use.
CTRL uses a CLM loss just like GPT and GPT-2, so it should work out of the box with this script (provided `past` is handled
similarly to `mems` for XLNet).
Changed calls and imports appropriately.

* Added detailed comments, changed variable names

Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain how it works. Also cleaned up variable names and made them more informative.

* Added tests for new data collator

Added tests in `tests/test_trainer.py` for `DataCollatorForPermutationLanguageModeling`, based on those for `DataCollatorForLanguageModeling`. A specific test has been added to check that odd-length sequences are rejected.

* Fixed styling issues
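
For reference, a minimal sketch of how the new collator is intended to plug into the `Trainer` API (not part of this commit; the checkpoint name, file path and output directory below are illustrative placeholders):

```python
from transformers import (
    AutoTokenizer,
    DataCollatorForPermutationLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
    XLNetLMHeadModel,
)

tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")

# The collator requires an even sequence length, so pick an even block_size.
train_dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=512)

data_collator = DataCollatorForPermutationLanguageModeling(
    tokenizer=tokenizer, plm_probability=1 / 6, max_span_length=5
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./xlnet-plm"),
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()
```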
parent 1d233286
@@ -14,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
"""
@@ -33,6 +33,7 @@ from transformers import (
    AutoModelWithLMHead,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    HfArgumentParser,
    LineByLineTextDataset,
    PreTrainedTokenizer,
@@ -101,6 +102,15 @@ class DataTrainingArguments:
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )

    block_size: int = field(
        default=-1,
@@ -207,8 +217,8 @@ def main():
    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the "
            "--mlm flag (masked language modeling)."
        )

    if data_args.block_size <= 0:
@@ -221,9 +231,14 @@ def main():
    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
    if config.model_type == "xlnet":
        data_collator = DataCollatorForPermutationLanguageModeling(
            tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length,
        )
    else:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
        )

    # Initialize our Trainer
    trainer = Trainer(
...
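
As a rough numeric illustration of how the new `--plm_probability` and `--max_span_length` arguments interact (a standalone sketch, not part of the diff): for every masked span the collator reserves a context of `span_length / plm_probability` tokens, so on average about `plm_probability` of each sequence ends up masked.

```python
# Standalone sketch: reserved context length per sampled span, using the defaults.
plm_probability = 1 / 6
max_span_length = 5

for span_length in range(1, max_span_length + 1):
    context_length = int(span_length / plm_probability)
    print(
        f"span_length={span_length} -> context_length={context_length} "
        f"(masked fraction ~ {span_length / context_length:.2f})"
    )

# With the defaults, a 5-token span reserves a 30-token context,
# i.e. roughly 1/6 of every context window is masked.
```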
@@ -400,7 +400,12 @@ if is_torch_available():
    # Trainer
    from .trainer import Trainer, torch_distributed_zero_first
    from .data.data_collator import (
        default_data_collator,
        DataCollator,
        DataCollatorForLanguageModeling,
        DataCollatorForPermutationLanguageModeling,
    )
    from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments

    # Benchmarks
...
@@ -21,8 +21,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
    Very simple data collator that:
    - simply collates batches of dict-like objects
    - Performs special handling for potential keys named:
        - ``label``: handles a single value (int or float) per object
        - ``label_ids``: handles a list of values per object
    - does not do any additional preprocessing

    i.e., Property names of the input object will be used as corresponding inputs to the model.
@@ -134,3 +134,126 @@ class DataCollatorForLanguageModeling:
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

@dataclass
class DataCollatorForPermutationLanguageModeling:
    """
    Data collator used for permutation language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet
    """

    tokenizer: PreTrainedTokenizer
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens

    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch = self._tensorize_batch(examples)
        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have one."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
            0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
            1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be masked)
            2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround span to be masked
            3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length - span_length]`` and mask tokens ``start_index:start_index + span_length``
            4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )
        labels = inputs.clone()
        # Creating the mask and target_mapping tensors
        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.size(1)

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th prediction corresponds to the i-th token.
            target_mapping[i] = torch.eye(labels.size(1))
        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask & special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(labels.size(1))
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
            ) & masked_indices[i]

        return inputs, perm_mask, target_mapping, labels
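
To make the factorisation-order construction above easier to follow, here is the same `perm_index` logic run on a toy sequence of length 8 outside the collator (a standalone sketch, not part of the diff):

```python
import torch

seq_len = 8  # must be even, matching the collator's requirement

# Linear order 0..7, split into two interleaved halves of length seq_len // 2.
perm_index = torch.arange(seq_len)
perm_index = perm_index.reshape((-1, seq_len // 2)).transpose(0, 1)  # shape (4, 2)
# Shuffle the rows: the two halves receive the same relative order and never cross over.
perm_index = perm_index[torch.randperm(seq_len // 2)]
# Flatten back into a length-8 factorisation order.
perm_index = torch.flatten(perm_index.transpose(0, 1))
print(perm_index)  # e.g. tensor([2, 1, 3, 0, 6, 5, 7, 4])
```

Because the two halves are permuted identically, positions are never mixed across halves, which is what the comments above rely on when the second half of the sequence is reused as `mems`.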
@@ -12,6 +12,7 @@ if is_torch_available():
        AutoModelForSequenceClassification,
        default_data_collator,
        DataCollatorForLanguageModeling,
        DataCollatorForPermutationLanguageModeling,
        GlueDataset,
        GlueDataTrainingArguments,
        TextDataset,
@@ -123,6 +124,34 @@ class DataCollatorIntegrationTest(unittest.TestCase):
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

    def test_plm(self):
        tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
        data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
        # ^ permutation lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
        self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
        self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
        self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

        example = [torch.randint(5, [5])]
        with self.assertRaises(ValueError):
            # Expect error due to odd sequence length
            data_collator(example)
@require_torch
class TrainerIntegrationTest(unittest.TestCase):
...