Unverified Commit 242ec31a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Refactor MLM (#12013)



* fix_torch_device_generate_test

* remove @

* finish refactor
Co-authored-by: default avatarPatrick von Platen <patrick@huggingface.co>
parent 4674061b
......@@ -34,6 +34,7 @@ import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import optax
......@@ -185,9 +186,7 @@ class DataTrainingArguments:
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
# Adapted from transformers/data/data_collator.py
# Letting here for now, let's discuss where it should live
@dataclass
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
"""
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
......@@ -196,12 +195,8 @@ class FlaxDataCollatorForLanguageModeling:
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
non-masked tokens and the value to predict for the masked token.
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
The probability with which to (randomly) mask tokens in the input.
.. note::
......@@ -212,11 +207,10 @@ class FlaxDataCollatorForLanguageModeling:
"""
tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
......@@ -228,15 +222,10 @@ class FlaxDataCollatorForLanguageModeling:
# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
if self.mlm:
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = batch["input_ids"].copy()
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
batch["labels"] = labels
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
return batch
def mask_tokens(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment