Unverified Commit 00440e35 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax MLM] Refactor run mlm with optax (#11745)



* refactor

* update

* update

* update

* refactor run mlm

* finalize

* refactor more

* fix typo

* update

* finish refactor

* modify run mlm

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* small fixes

* upload

* upload

* finish run mlm script
Co-authored-by: default avatarPatrick von Platen <patrick@huggingface.co>
parent 43891be1
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Language model training examples
The following example showcases how to train a language model from scratch
using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
## Masked language modeling
In the following, we demonstrate how to train a bi-directional transformer model
using masked language modeling objective as introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`roberta-base`**](https://huggingface.co/roberta-base)
in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script.
```bash
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in `${MODEL_DIR}`
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>",
])
# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
```
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading and storing [`**roberta-base**`](https://huggingface.co/roberta-base)
in the local model folder:
```python
from transformers import RobertaConfig
model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}
config = RobertaConfig.from_pretrained("roberta-base")
config.save_pretrained(model_dir)
```
### Train model
Next we can run the example script to pretrain the model:
```bash
./run_mlm_flax.py \
--output_dir="./runs" \
--model_type="roberta" \
--config_name="${MODEL_DIR}" \
--tokenizer_name="${MODEL_DIR}" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
--weight_decay="0.01" \
--per_device_train_batch_size="128" \
--per_device_eval_batch_size="128" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--overwrite_output_dir \
--pad_to_max_length \
--num_train_epochs="18" \
--adam_beta1="0.9" \
--adam_beta2="0.98"
```
Training should converge at a loss and accuracy
of 1.78 and 0.64 respectively after 18 epochs on a single TPUv3-8.
This should take less than 18 hours.
Training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg).
For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
look at [this TODO: (Patrick)]() google colab.
## TODO(Patrick): Add comparison with PyTorch GPU/TPU
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=masked-lm
import logging
import os
import sys
import time
from dataclasses import dataclass, field
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
......@@ -35,11 +36,10 @@ from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils
from flax.optim import Adam
from flax.training import common_utils
from flax.training.common_utils import get_metrics
from jax.nn import log_softmax
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
......@@ -269,167 +269,30 @@ class FlaxDataCollatorForLanguageModeling:
return inputs, labels
def create_learning_rate_scheduler(
factors="constant * linear_warmup * rsqrt_decay",
base_learning_rate=0.5,
warmup_steps=1000,
decay_factor=0.5,
steps_per_decay=20000,
steps_per_cycle=100000,
):
"""Creates learning rate schedule.
Interprets factors in the factors string which can consist of:
* constant: interpreted as the constant value,
* linear_warmup: interpreted as linear warmup until warmup_steps,
* rsqrt_decay: divide by square root of max(step, warmup_steps)
* rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
* decay_every: Every k steps decay the learning rate by decay_factor.
* cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
Args:
factors: string, factors separated by "*" that defines the schedule.
base_learning_rate: float, the starting constant for the lr schedule.
warmup_steps: int, how many steps to warm up for in the warmup schedule.
decay_factor: float, the amount to decay the learning rate by.
steps_per_decay: int, how often to decay the learning rate.
steps_per_cycle: int, steps per cycle when using cosine decay.
Returns:
a function learning_rate(step): float -> {"learning_rate": float}, the
step-dependent lr.
"""
factors = [n.strip() for n in factors.split("*")]
def step_fn(step):
"""Step to learning rate function."""
ret = 1.0
for name in factors:
if name == "constant":
ret *= base_learning_rate
elif name == "linear_warmup":
ret *= jnp.minimum(1.0, step / warmup_steps)
elif name == "rsqrt_decay":
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
elif name == "rsqrt_normalized_decay":
ret *= jnp.sqrt(warmup_steps)
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
elif name == "decay_every":
ret *= decay_factor ** (step // steps_per_decay)
elif name == "cosine_decay":
progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
else:
raise ValueError(f"Unknown factor {name}.")
return jnp.asarray(ret, dtype=jnp.float32)
return step_fn
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
"""Compute summary metrics."""
loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing)
acc, _ = accuracy(logits, labels, weights)
metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer}
metrics = jax.lax.psum(metrics, axis_name="batch")
return metrics
def accuracy(logits, targets, weights=None):
"""Compute weighted accuracy for log probs and targets.
Args:
logits: [batch, length, num_classes] float array.
targets: categorical targets [batch, length] int array.
weights: None or array of shape [batch, length]
Returns:
Tuple of scalar loss and batch normalizing factor.
"""
if logits.ndim != targets.ndim + 1:
raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
loss *= weights
return loss.sum(), weights.sum()
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
"""Compute cross entropy and entropy for log probs and targets.
Args:
logits: [batch, length, num_classes] float array.
targets: categorical targets [batch, length] int array.
weights: None or array of shape [batch, length]
label_smoothing: label smoothing constant, used to determine the on and off values.
Returns:
Tuple of scalar loss and batch normalizing factor.
"""
if logits.ndim != targets.ndim + 1:
raise ValueError(f"Incorrect shapes. Got shape {logits.shape} logits and {targets.shape} targets")
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence)
loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
loss = loss - normalizing_constant
if weights is not None:
loss = loss * weights
normalizing_factor = weights.sum()
else:
normalizing_factor = np.prod(targets.shape)
return loss.sum(), normalizing_factor
def training_step(optimizer, batch, dropout_rng):
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
def loss_fn(params):
targets = batch.pop("labels")
# Hide away tokens which doesn't participate in the optimization
token_mask = jnp.where(targets > 0, 1.0, 0.0)
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, weight_sum = cross_entropy(logits, targets, token_mask)
return loss / weight_sum
step = optimizer.state.step
lr = lr_scheduler_fn(step)
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(optimizer.target)
grad = jax.lax.pmean(grad, "batch")
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
return loss, optimizer, new_dropout_rng
def eval_step(params, batch):
"""
Calculate evaluation metrics on a batch.
"""
targets = batch.pop("labels")
# Hide away tokens which doesn't participate in the optimization
token_mask = jnp.where(targets > 0, 1.0, 0.0)
logits = model(**batch, params=params, train=False)[0]
return compute_metrics(logits, targets, token_mask)
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
nb_samples = len(samples_idx)
samples_to_remove = nb_samples % batch_size
num_samples = len(samples_idx)
samples_to_remove = num_samples % batch_size
if samples_to_remove != 0:
samples_idx = samples_idx[:-samples_to_remove]
sections_split = nb_samples // batch_size
sections_split = num_samples // batch_size
batch_idx = np.split(samples_idx, sections_split)
return batch_idx
def write_metric(train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
if __name__ == "__main__":
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
......@@ -486,6 +349,7 @@ if __name__ == "__main__":
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
data_args.dataset_name,
......@@ -610,7 +474,6 @@ if __name__ == "__main__":
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
......@@ -619,7 +482,7 @@ if __name__ == "__main__":
)
# Enable tensorboard only on the master node
if has_tensorboard and jax.host_id() == 0:
if has_tensorboard and jax.process_index() == 0:
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
# Data collator
......@@ -632,58 +495,128 @@ if __name__ == "__main__":
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
# Setup optimizer
optimizer = Adam(
learning_rate=training_args.learning_rate,
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
# Create learning rate schedule
warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
)
decay_fn = optax.linear_schedule(
init_value=training_args.learning_rate,
end_value=0,
transition_steps=num_train_steps - training_args.warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=1e-8,
weight_decay=training_args.weight_decay,
beta1=training_args.adam_beta1,
beta2=training_args.adam_beta2,
).create(model.params)
# Create learning rate scheduler
# warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
lr_scheduler_fn = create_learning_rate_scheduler(
base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
)
# Create parallel version of the training and evaluation steps
p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
# Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
# Replicate the optimizer on each device
optimizer = jax_utils.replicate(optimizer)
# Define gradient update step fn
def train_step(state, batch, dropout_rng):
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
# Store some constant
nb_epochs = int(training_args.num_train_epochs)
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
def loss_fn(params):
labels = batch.pop("labels")
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
for epoch in epochs:
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
# compute loss, ignore padded input tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
# take average
loss = loss.sum() / label_mask.sum()
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean(
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
)
return new_state, metrics, new_dropout_rng
# Create parallel version of the train step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# Define eval fn
def eval_step(params, batch):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
# compute loss, ignore padded input tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
# compute accuracy
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
# summarize metrics
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
metrics = jax.lax.psum(metrics, axis_name="batch")
return metrics
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
# Replicate the train state on each device
state = jax_utils.replicate(state)
train_metrics = []
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, training_rng, eval_rng = jax.random.split(rng, 3)
rng, input_rng = jax.random.split(rng)
# Generate an epoch by shuffling sampling indices from the train dataset
nb_training_samples = len(tokenized_datasets["train"])
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
num_train_samples = len(tokenized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
model_inputs = common_utils.shard(model_inputs.data)
loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)
model_inputs = shard(model_inputs.data)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
train_metrics.append(train_metric)
epochs.write(f"Loss: {loss}")
train_time += time.time() - train_start
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
nb_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(nb_eval_samples)
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
......@@ -692,26 +625,27 @@ if __name__ == "__main__":
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
model_inputs = common_utils.shard(model_inputs.data)
metrics = p_eval_step(optimizer.target, model_inputs)
model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics)
eval_metrics_np = get_metrics(eval_metrics)
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
eval_normalizer = eval_metrics_np.pop("normalizer")
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
epochs.desc = (
f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
)
# Save metrics
if has_tensorboard and jax.host_id() == 0:
for name, value in eval_summary.items():
summary_writer.scalar(name, value, epoch)
# save last checkpoint
if jax.host_id() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
model.save_pretrained(training_args.output_dir, params=params)
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_metric(train_metrics, eval_metrics, train_time, cur_step)
# save last checkpoint
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
git+https://github.com/google/flax.git
flax>=0.3.4
git+https://github.com/deepmind/optax.git
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment