Unverified Commit 31c3e7e7 authored by Patrick von Platen, committed by GitHub

[Flax] Add T5 pretraining script (#12355)



* fix_torch_device_generate_test

* remove @

* add length computation

* finish masking

* finish

* upload

* fix some bugs

* finish

* fix dependency table

* correct tensorboard

* Apply suggestions from code review

* correct processing

* slight change init

* correct some more mistakes

* apply suggestions

* improve readme

* fix indent

* Apply suggestions from code review
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* correct tokenizer

* finish

* finish

* finish

* finish
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
parent e2770748
...@@ -241,6 +241,140 @@ of 3.24 and 25.72 respectively after 20 epochs on a single TPUv3-8.
This should take less than ~21 hours.
Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).
## T5-like span-masked language modeling
In the following, we demonstrate how to train a T5 model using the span-masked language modeling
objective proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`t5-small`**](https://huggingface.co/t5-small)
in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize it to your needs if your dataset requires extra processing.
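For instance, if your corpus lives in local plain-text files rather than on the Hub, the loading step could be swapped out roughly as follows (a minimal sketch; the file names are hypothetical and one training example per line is assumed):

```python
from datasets import load_dataset

# Hypothetical local files with one training example per line
dataset = load_dataset(
    "text",
    data_files={"train": ["my_corpus_part1.txt", "my_corpus_part2.txt"]},
    split="train",
)
print(dataset[0]["text"])
```

The rest of this guide sticks to the Norwegian OSCAR split used in the commands below.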
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-t5-small"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create norwegian-t5-small
```
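If the CLI is not authenticated yet, log in first so the repository is created under your account (quick sketch):

```bash
huggingface-cli login  # stores your hf.co credentials locally for the commands below
```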
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/norwegian-t5-small
```
To ensure that all TensorBoard traces are uploaded correctly, we need to track them with Git LFS.
You can run the following command inside your model repo to do so.
```
cd norwegian-t5-small
git lfs track "*tfevents*"
```
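To double-check that the pattern was registered, you can list the tracked patterns (an optional sanity check):

```bash
git lfs track       # lists the currently tracked patterns
cat .gitattributes  # the "*tfevents*" rule should appear here
```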
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add symbolic links to the `run_t5_mlm_flax.py` and `t5_tokenizer_model.py` scripts.
```bash
export MODEL_DIR="./norwegian-t5-small"
ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
```
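The symlinks above assume a local clone of 🤗 Transformers at `~/transformers`; if you don't have one yet, a quick sketch (adjust the path to your setup):

```bash
git clone https://github.com/huggingface/transformers.git ~/transformers
```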
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model.
We make use of the [tokenizers](https://github.com/huggingface/tokenizers) library to train
a SentencePiece unigram tokenizer as shown in [t5_tokenizer_model.py](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling/t5_tokenizer_model.py),
which is heavily inspired by [yandex-research/DeDLOC's tokenizer model](https://github.com/yandex-research/DeDLOC/blob/5c994bc64e573702a9a79add3ecd68b38f14b548/sahajbert/tokenizer/tokenizer_model.py).
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and subsequently saved in `${MODEL_DIR}`.
This can take up to 120 minutes depending on your hardware ☕☕☕.
```python
import datasets

from t5_tokenizer_model import SentencePieceUnigramTokenizer


vocab_size = 32_000
input_sentence_size = None
model_dir = "./norwegian-t5-small"  # ${MODEL_DIR}

# Initialize a dataset
dataset = datasets.load_dataset("oscar", name="unshuffled_deduplicated_no", split="train")

tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")


# Build an iterator over this dataset
def batch_iterator(input_sentence_size=None):
    if input_sentence_size is None:
        input_sentence_size = len(dataset)
    batch_length = 100
    for i in range(0, input_sentence_size, batch_length):
        yield dataset[i: i + batch_length]["text"]


# Train tokenizer
tokenizer.train_from_iterator(
    iterator=batch_iterator(input_sentence_size=input_sentence_size),
    vocab_size=vocab_size,
    show_progress=True,
)

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
```
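Once training has finished, you can sanity-check the saved tokenizer by loading it back with the 🤗 Tokenizers API and encoding a sample sentence (a minimal sketch; the example sentence is arbitrary):

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("./norwegian-t5-small/tokenizer.json")
encoding = tokenizer.encode("Jeg liker å trene modeller.")

print(encoding.tokens)  # sub-word pieces; the post-processor appends the </s> token
print(encoding.ids)
```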
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading the configuration of [**`t5-small`**](https://huggingface.co/t5-small)
and storing it in the local model folder:
```python
from transformers import T5Config

model_dir = "./norwegian-t5-small"  # ${MODEL_DIR}

config = T5Config.from_pretrained("t5-small")
config.save_pretrained(model_dir)
```
```
### Train model
Next we can run the example script to pretrain the model:
```bash
./run_t5_mlm_flax.py \
    --output_dir="${MODEL_DIR}" \
    --model_type="t5" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="512" \
    --per_device_train_batch_size="16" \
    --per_device_eval_batch_size="16" \
    --learning_rate="1e-3" \
    --weight_decay="0.001" \
    --warmup_steps="5000" \
    --overwrite_output_dir \
    --num_train_epochs="10" \
    --push_to_hub
```
Training should converge at a loss and accuracy
of XXX and XXX respectively after 10 epochs on a single TPUv3-8.
This should take less than 18 hours.
Training statistics can be accessed directly on the 🤗 [hub (TODO)]().
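While the run is in progress, the TensorBoard traces written to the output directory can also be monitored locally, assuming `tensorboard` is installed in your environment:

```bash
tensorboard --logdir ./norwegian-t5-small
```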
## Runtime evaluation
......
...@@ -582,12 +582,12 @@ if __name__ == "__main__":
     # Replicate the train state on each device
     state = jax_utils.replicate(state)
 
-    train_metrics = []
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
+        train_metrics = []
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
......
This diff is collapsed.
#!/usr/bin/env python3
import json
from typing import Iterator, List, Union

from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers
from tokenizers.implementations.base_tokenizer import BaseTokenizer
from tokenizers.models import Unigram
from tokenizers.processors import TemplateProcessing


class SentencePieceUnigramTokenizer(BaseTokenizer):
    """
    This class is a copy of `DeDLOC's tokenizer implementation <https://github.com/yandex-research/DeDLOC/blob/main/sahajbert/tokenizer/tokenizer_model.py>`__ .

    Custom SentencePiece Unigram Tokenizer with NMT, NFKC, whitespace and lower-casing normalization.
    Represents the Unigram algorithm, with the pretokenization used by SentencePiece.
    """

    def __init__(
        self,
        replacement: str = "▁",
        add_prefix_space: bool = True,
        unk_token: Union[str, AddedToken] = "<unk>",
        eos_token: Union[str, AddedToken] = "</s>",
        pad_token: Union[str, AddedToken] = "<pad>",
    ):
        self.special_tokens = {
            "pad": {"id": 0, "token": pad_token},
            "eos": {"id": 1, "token": eos_token},
            "unk": {"id": 2, "token": unk_token},
        }

        self.special_tokens_list = [None] * len(self.special_tokens)
        for token_dict in self.special_tokens.values():
            self.special_tokens_list[token_dict["id"]] = token_dict["token"]

        tokenizer = Tokenizer(Unigram())

        tokenizer.normalizer = normalizers.Sequence(
            [
                normalizers.Nmt(),
                normalizers.NFKC(),
                normalizers.Replace(Regex(" {2,}"), " "),
                normalizers.Lowercase(),
            ]
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
                pre_tokenizers.Digits(individual_digits=True),
                pre_tokenizers.Punctuation(),
            ]
        )
        tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)

        tokenizer.post_processor = TemplateProcessing(
            single=f"$A {self.special_tokens['eos']['token']}",
            special_tokens=[(self.special_tokens["eos"]["token"], self.special_tokens["eos"]["id"])],
        )

        parameters = {
            "model": "SentencePieceUnigram",
            "replacement": replacement,
            "add_prefix_space": add_prefix_space,
        }

        super().__init__(tokenizer, parameters)

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int = 8000,
        show_progress: bool = True,
    ):
        """Train the model using the given files"""

        trainer = trainers.UnigramTrainer(
            vocab_size=vocab_size,
            special_tokens=self.special_tokens_list,
            show_progress=show_progress,
        )

        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(files, trainer=trainer)

        self.add_unk_id()

    def train_from_iterator(
        self,
        iterator: Union[Iterator[str], Iterator[Iterator[str]]],
        vocab_size: int = 8000,
        show_progress: bool = True,
    ):
        """Train the model using the given iterator"""

        trainer = trainers.UnigramTrainer(
            vocab_size=vocab_size,
            special_tokens=self.special_tokens_list,
            show_progress=show_progress,
        )

        self._tokenizer.train_from_iterator(iterator, trainer=trainer)

        self.add_unk_id()

    def add_unk_id(self):
        tokenizer_json = json.loads(self._tokenizer.to_str())

        tokenizer_json["model"]["unk_id"] = self.special_tokens["unk"]["id"]

        self._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
...@@ -378,7 +378,7 @@ official [flax example folder](https://github.com/huggingface/transformers/tree/
 - [Masked language modeling (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_mlm_flax.py)
 - [Text classification (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py)
 - [Summarization / Seq2Seq (BART, MBART, T5)](https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py)
-- [(TODO) Masked Seq2Seq pretraining (T5)]( )
+- [Masked Seq2Seq pretraining (T5)](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py)
 - [(TODO) Image classification (ViT)]( )
 - [(TODO) CLIP pretraining, fine-tuning (CLIP)]( )
......
...@@ -141,13 +141,6 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
     ]
 )
 
-FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
-    [
-        # Model for Seq2Seq Causal LM mapping
-        (BartConfig, FlaxBartForConditionalGeneration)
-    ]
-)
-
 FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
     [
         # Model for Sequence Classification mapping
......
...@@ -185,31 +185,32 @@ class FlaxT5Attention(nn.Module):
         self.dropout = self.config.dropout_rate
         self.inner_dim = self.n_heads * self.key_value_proj_dim
 
-        inner_dim_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
-        d_model_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+        kv_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+        o_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
 
         self.q = nn.Dense(
             self.inner_dim,
             use_bias=False,
-            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(q_init_std, self.dtype),
             dtype=self.dtype,
         )
         self.k = nn.Dense(
             self.inner_dim,
             use_bias=False,
-            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
             dtype=self.dtype,
         )
         self.v = nn.Dense(
             self.inner_dim,
             use_bias=False,
-            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
             dtype=self.dtype,
         )
         self.o = nn.Dense(
             self.d_model,
             use_bias=False,
-            kernel_init=jax.nn.initializers.normal(inner_dim_init_std, self.dtype),
+            kernel_init=jax.nn.initializers.normal(o_init_std, self.dtype),
             dtype=self.dtype,
         )
...@@ -217,7 +218,7 @@ class FlaxT5Attention(nn.Module):
         self.relative_attention_bias = nn.Embed(
             self.relative_attention_num_buckets,
             self.n_heads,
-            embedding_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            embedding_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
             dtype=self.dtype,
         )
......