Unverified commit 242ec31a, authored by Patrick von Platen and committed by GitHub

[Flax] Refactor MLM (#12013)



* fix_torch_device_generate_test

* remove @

* finish refactor
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 4674061b
@@ -34,6 +34,7 @@ import numpy as np
 from datasets import load_dataset
 from tqdm import tqdm
 
+import flax
 import jax
 import jax.numpy as jnp
 import optax
@@ -185,9 +186,7 @@ class DataTrainingArguments:
             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
-# Adapted from transformers/data/data_collator.py
-# Letting here for now, let's discuss where it should live
-@dataclass
+@flax.struct.dataclass
 class FlaxDataCollatorForLanguageModeling:
     """
     Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
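The hunk above contains the core of the refactor: the plain `@dataclass` becomes `@flax.struct.dataclass`. As a rough illustration (not part of the commit), the Flax decorator produces a frozen dataclass that is registered as a JAX pytree, so instances can cross `jax.jit`/`jax.pmap` boundaries and are updated by copy rather than by mutation. A minimal sketch, using a hypothetical `MaskingConfig` class:

    import flax

    @flax.struct.dataclass
    class MaskingConfig:  # hypothetical stand-in, not from the commit
        mlm_probability: float = 0.15

    cfg = MaskingConfig()
    # cfg.mlm_probability = 0.3  # would raise an error: instances are frozen
    cfg2 = cfg.replace(mlm_probability=0.3)  # updates return a new instance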
@@ -196,12 +195,8 @@ class FlaxDataCollatorForLanguageModeling:
     Args:
         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
             The tokenizer used for encoding the data.
-        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
-            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
-            non-masked tokens and the value to predict for the masked token.
         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
-            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
+            The probability with which to (randomly) mask tokens in the input.
 
     .. note::
@@ -212,11 +207,10 @@ class FlaxDataCollatorForLanguageModeling:
     """
 
     tokenizer: PreTrainedTokenizerBase
-    mlm: bool = True
     mlm_probability: float = 0.15
 
     def __post_init__(self):
-        if self.mlm and self.tokenizer.mask_token is None:
+        if self.tokenizer.mask_token is None:
             raise ValueError(
                 "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                 "You should pass `mlm=False` to train on causal language modeling instead."
@@ -228,15 +222,10 @@ class FlaxDataCollatorForLanguageModeling:
         # If special token mask has been preprocessed, pop it from the dict.
         special_tokens_mask = batch.pop("special_tokens_mask", None)
 
-        if self.mlm:
-            batch["input_ids"], batch["labels"] = self.mask_tokens(
-                batch["input_ids"], special_tokens_mask=special_tokens_mask
-            )
-        else:
-            labels = batch["input_ids"].copy()
-            if self.tokenizer.pad_token_id is not None:
-                labels[labels == self.tokenizer.pad_token_id] = -100
-            batch["labels"] = labels
+        batch["input_ids"], batch["labels"] = self.mask_tokens(
+            batch["input_ids"], special_tokens_mask=special_tokens_mask
+        )
 
         return batch
 
     def mask_tokens(
...
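The body of mask_tokens is truncated in this view. For orientation only, here is a generic NumPy sketch of the standard BERT masking recipe (80% mask token, 10% random token, 10% unchanged); the names and details are illustrative and not taken from this commit.

    import numpy as np

    def bert_style_mask(inputs, mask_token_id, vocab_size, special_tokens_mask, mlm_probability=0.15):
        # Labels start as a copy of the inputs; -100 marks positions the loss ignores.
        labels = inputs.copy()
        probability_matrix = np.full(labels.shape, mlm_probability)
        probability_matrix[special_tokens_mask.astype(bool)] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype(bool)
        labels[~masked_indices] = -100

        # 80% of masked positions are replaced with the mask token.
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & masked_indices
        inputs[indices_replaced] = mask_token_id

        # 10% are replaced with a random token (half of the remaining 20%).
        indices_random = (
            np.random.binomial(1, np.full(labels.shape, 0.5)).astype(bool) & masked_indices & ~indices_replaced
        )
        inputs[indices_random] = np.random.randint(vocab_size, size=labels.shape)[indices_random]

        # The final 10% keep their original token.
        return inputs, labels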