Unverified commit 4ce6bcc3 authored by Marc van Zee, committed by GitHub

Adds Flax BERT finetuning example on GLUE (#11564)



* Adds Flax BERT finetuning example

* fix traced jax tensor type

* Use Optax losses and learning schedulers

* Add 1GPU training results

* merge into master & make style

* fix input

* del file

* Fix bug in loss and add torch runs

* finish bert flax fine-tune

* Update examples/flax/text-classification/README.md

* Update examples/flax/text-classification/run_flax_glue.py

* add requirements

* finalize

* finalize
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent f13f1f8f
<!---
Copyright 2021 The Google Flax Team Authors and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Text classification examples
## GLUE tasks
Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py).
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
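Before running the script, install the example's dependencies. Assuming the requirement list shipped with this example is saved as `requirements.txt`, you can run:
```bash
pip install -r requirements.txt
```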
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
```bash
export TASK_NAME=mrpc
python run_flax_glue.py \
--model_name_or_path bert-base-cased \
--task_name $TASK_NAME \
--max_length 128 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--output_dir /tmp/$TASK_NAME/
```
where the task name can be one of `cola`, `mnli`, `mrpc`, `qnli`, `qqp`, `rte`, `sst2`, `stsb`, `wnli` (the valid `--task_name` choices of the script).
Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored as TensorFlow event files in `--output_dir`.
You can see the results by running `tensorboard` in that directory:
```bash
$ tensorboard --logdir .
```
### Accuracy Evaluation
We train five replicas and report the mean accuracy and standard deviation on the dev set below.
We use the settings of the command above, with two exceptions: MRPC and WNLI are tiny datasets, so we train them
for 5 epochs instead of 3, and we use a total train batch size of 32 (we train on 8 Cloud v3 TPUs, so a per-device
batch size of 4). On the tasks other than MRPC and WNLI we train for 3 epochs because that is the standard,
but the training curves of some of them (e.g., SST-2, STS-B) suggest the models are undertrained and
longer training could yield better results.
In the TensorBoard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1zKL_xn32HwbxkFMxB3ftca-soTHAuBFgIhYhOhCnZ4E/edit?usp=sharing).
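For example, to reproduce run 1 on MRPC with the settings described above (5 epochs for MRPC; the per-device batch size of 4 assumes 8 devices, giving the total batch size of 32):
```bash
python run_flax_glue.py \
--model_name_or_path bert-base-cased \
--task_name mrpc \
--max_length 128 \
--learning_rate 2e-5 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--output_dir /tmp/mrpc/ \
--seed 1
```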
| Task  | Metric                 | Score (best run) | Score (avg/5 runs) | Stdev     | TensorBoard                                                                   |
|-------|------------------------|------------------|--------------------|-----------|-------------------------------------------------------------------------------|
| CoLA  | Matthews corr.         | 59.57            | 58.04              | 1.81      | [tensorboard.dev](https://tensorboard.dev/experiment/f4OvQpWtRq6CvddpxGBd0A/) |
| SST-2 | Accuracy               | 92.43            | 91.79              | 0.59      | [tensorboard.dev](https://tensorboard.dev/experiment/BYFwa49MRTaLIn93DgAEtA/) |
| MRPC  | F1/Accuracy            | 89.50/84.8       | 88.70/84.02        | 0.56/0.48 | [tensorboard.dev](https://tensorboard.dev/experiment/9ZWH5xwXRS6zEEUE4RaBhQ/) |
| STS-B | Pearson/Spearman corr. | 90.00/88.71      | 89.09/88.61        | 0.51/0.07 | [tensorboard.dev](https://tensorboard.dev/experiment/mUlI5B9QQ0WGEJip7p3Tng/) |
| QQP   | Accuracy/F1            | 90.88/87.64      | 90.75/87.53        | 0.11/0.13 | [tensorboard.dev](https://tensorboard.dev/experiment/pO6h75L3SvSXSWRcgljXKA/) |
| MNLI  | Matched acc.           | 84.06            | 83.88              | 0.16      | [tensorboard.dev](https://tensorboard.dev/experiment/LKwaOH18RMuo7nJkESrpKg/) |
| QNLI  | Accuracy               | 91.01            | 90.86              | 0.18      | [tensorboard.dev](https://tensorboard.dev/experiment/qesXxNcaQhmKxPmbw1sOoA/) |
| RTE   | Accuracy               | 66.80            | 65.27              | 1.07      | [tensorboard.dev](https://tensorboard.dev/experiment/Z84xC0r6RjyzT4SLqiAbzQ/) |
| WNLI  | Accuracy               | 39.44            | 32.96              | 5.85      | [tensorboard.dev](https://tensorboard.dev/experiment/gV73w9v0RIKrqVw32PZbAQ/) |
Some of these results differ significantly from the ones reported on the test set of the GLUE benchmark
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
### Runtime Evaluation
We also ran each task once on a single V100 GPU, 8 V100 GPUs, and 8 Cloud v3 TPUs, and report the
overall training time below. For comparison we ran PyTorch's [run_glue.py](https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py) on a single GPU (last column).
| Task | 8 TPU | 8 GPU | 1 GPU | 1 GPU (PyTorch) |
|-------|---------|---------|------------|-----------------|
| CoLA | 1m 46s | 1m 26s | 3m 6s | 4m 6s |
| SST-2 | 5m 30s | 6m 28s | 22m 6s | 34m 37s |
| MRPC | 1m 32s | 1m 14s | 2m 17s | 2m 56s |
| STS-B | 1m 33s | 1m 12s | 2m 11s | 2m 48s |
| QQP | 24m 40s | 31m 48s | 1h 20m 15s | 2h 54m |
| MNLI | 26m 30s | 33m 55s | 2h 7m 30s | 3h 7m 6s |
| QNLI | 8m | 9m 40s | 34m 20s | 49m 8s |
| RTE | 1m 21s | 55s | 1m 8s | 1m 16s |
| WNLI | 1m 12s | 48s | 38s | 36s |
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
git+https://github.com/google/flax.git
git+https://github.com/deepmind/optax.git
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import argparse
import logging
import os
import random
import time
from itertools import chain
from typing import Any, Callable, Dict, Tuple
import datasets
from datasets import load_dataset, load_metric
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import linen as nn
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
logger = logging.getLogger(__name__)
Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
parser.add_argument(
"--task_name",
type=str,
default=None,
help="The name of the glue task to train on.",
choices=list(task_to_keys.keys()),
)
parser.add_argument(
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
)
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
)
parser.add_argument(
"--max_length",
type=int,
default=128,
help=(
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
" sequences shorter will be padded."
),
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
)
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=2, help="A seed for reproducible training.")
args = parser.parse_args()
# Sanity checks
if args.task_name is None and args.train_file is None and args.validation_file is None:
raise ValueError("Need either a task name or a training/validation file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
return args
def create_train_state(
model: FlaxAutoModelForSequenceClassification,
learning_rate_fn: Callable[[int], float],
is_regression: bool,
num_labels: int,
) -> train_state.TrainState:
"""Create initial training state."""
class TrainState(train_state.TrainState):
"""Train state with an Optax optimizer.
The two functions below differ depending on whether the task is classification
or regression.
Args:
logits_fn: Applied to last layer to obtain the logits.
loss_fn: Function to compute the loss.
"""
logits_fn: Callable = struct.field(pytree_node=False)
loss_fn: Callable = struct.field(pytree_node=False)
# Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
def adamw(weight_decay):
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay)
def traverse(fn):
def mask(data):
flat = traverse_util.flatten_dict(data)
return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
return mask
# We use Optax's "masking" functionality to create a multi-optimizer: one
# AdamW instance that applies weight decay and one that does not. Note that
# optax.masked applies its inner transformation only to the parameters whose
# mask entry is True, so each parameter is updated by exactly one of the two.
# Biases and LayerNorm parameters are excluded from weight decay.
decay_path = lambda p: not any(x in p for x in ["bias", "LayerNorm"]) # noqa: E731
tx = optax.chain(
optax.masked(adamw(0.01), mask=traverse(lambda path, _: decay_path(path))),
optax.masked(adamw(0.0), mask=traverse(lambda path, _: not decay_path(path))),
)
if is_regression:
def mse_loss(logits, labels):
return jnp.mean((logits[..., 0] - labels) ** 2)
return TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=tx,
logits_fn=lambda logits: logits[..., 0],
loss_fn=mse_loss,
)
else: # Classification.
def cross_entropy_loss(logits, labels):
# optax.softmax_cross_entropy expects unnormalized logits.
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
return jnp.mean(xentropy)
return TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=tx,
logits_fn=lambda logits: logits.argmax(-1),
loss_fn=cross_entropy_loss,
)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
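# Example with illustrative numbers: for train_ds_size=3668 (the MRPC train
# split), train_batch_size=32, num_train_epochs=3 and num_warmup_steps=0,
# the schedule decays linearly from `learning_rate` to 0 over
# 3668 // 32 * 3 = 342 steps.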
return schedule_fn
def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
steps_per_epoch = len(dataset) // batch_size
perms = jax.random.permutation(rng, len(dataset))
perms = perms[: steps_per_epoch * batch_size] # Skip incomplete batch.
perms = perms.reshape((steps_per_epoch, batch_size))
for perm in perms:
batch = dataset[perm]
batch = {k: jnp.array(v) for k, v in batch.items()}
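# shard() reshapes each array from (batch_size, ...) to
# (local_device_count, batch_size // local_device_count, ...), the layout
# jax.pmap expects for its leading device axis.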
batch = shard(batch)
yield batch
def glue_eval_data_collator(dataset: Dataset, batch_size: int):
"""Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
for i in range(len(dataset) // batch_size):
batch = dataset[i * batch_size : (i + 1) * batch_size]
batch = {k: jnp.array(v) for k, v in batch.items()}
batch = shard(batch)
yield batch
def main():
args = parse_args()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
# For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
# sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not
# named 'label' if at least two columns are provided.
# If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
# single column. You can easily tweak this behavior (see below)
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.task_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset("glue", args.task_name)
else:
# Loading the dataset from local csv or json file.
data_files = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Labels
if args.task_name is not None:
is_regression = args.task_name == "stsb"
if not is_regression:
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
else:
num_labels = 1
else:
# Trying to have good defaults here, don't hesitate to tweak to your needs.
is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
if is_regression:
num_labels = 1
else:
# A useful fast method:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
label_list = raw_datasets["train"].unique("label")
label_list.sort() # Let's sort it for determinism
num_labels = len(label_list)
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
# Preprocessing the datasets
if args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[args.task_name]
else:
# Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
sentence1_key, sentence2_key = "sentence1", "sentence2"
else:
if len(non_label_column_names) >= 2:
sentence1_key, sentence2_key = non_label_column_names[:2]
else:
sentence1_key, sentence2_key = non_label_column_names[0], None
# Some models have set the order of the labels to use, so let's make sure we do use it.
label_to_id = None
if (
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
and args.task_name is not None
and not is_regression
):
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
logger.info(
f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
"Using it!"
)
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: "
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result."
)
elif args.task_name is None:
label_to_id = {v: i for i, v in enumerate(label_list)}
def preprocess_function(examples):
# Tokenize the texts
texts = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)
if "label" in examples:
if label_to_id is not None:
# Map labels to IDs (not necessary for GLUE tasks)
result["labels"] = [label_to_id[l] for l in examples["label"]]
else:
# In all cases, rename the column to labels because the model will expect that.
result["labels"] = examples["label"]
return result
processed_datasets = raw_datasets.map(
preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Define a summary writer
summary_writer = tensorboard.SummaryWriter(args.output_dir)
summary_writer.hparams(vars(args))
def write_metric(train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
num_epochs = int(args.num_train_epochs)
rng = jax.random.PRNGKey(args.seed)
train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
learning_rate_fn = create_learning_rate_fn(
len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
)
state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels)
# define step functions
def train_step(
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, float]:
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
targets = batch.pop("labels")
def loss_fn(params):
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = state.loss_fn(logits, targets)
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
return new_state, metrics
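# jax.pmap compiles `train_step` with XLA and runs it in parallel on all local
# devices; `donate_argnums=(0,)` donates the input state's buffers to the
# computation so XLA can reuse that memory for the returned state.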
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
def eval_step(state, batch):
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
return state.logits_fn(logits)
p_eval_step = jax.pmap(eval_step, axis_name="batch")
if args.task_name is not None:
metric = load_metric("glue", args.task_name)
else:
metric = load_metric("accuracy")
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0
# make sure the weights are replicated on each device; this is done once,
# before the epoch loop, and `state` stays replicated across epochs
state = replicate(state)
for epoch in range(1, num_epochs + 1):
logger.info(f"Epoch {epoch}")
logger.info(" Training...")
train_start = time.time()
train_metrics = []
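# Split off fresh subkeys: `input_rng` drives the shuffling of training
# batches and `dropout_rng` seeds dropout; a JAX PRNG key must be split
# rather than reused to obtain independent randomness.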
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
# train
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
dropout_rngs = shard_prng_key(dropout_rng)
state, metrics = p_train_step(state, batch, dropout_rngs)
train_metrics.append(metrics)
train_time += time.time() - train_start
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
logger.info(" Evaluating...")
rng, input_rng = jax.random.split(rng)
# evaluate
for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
labels = batch.pop("labels")
predictions = p_eval_step(state, batch)
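# predictions (and labels) carry a leading device axis of size
# jax.local_device_count(); chain(*...) flattens them back into a single
# stream of per-example values for the metric.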
metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# copy the weights from the first device; `state` itself stays replicated
single_device_state = unreplicate(state)
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: jnp.array(v) for k, v in batch.items()}
labels = batch.pop("labels")
predictions = eval_step(single_device_state, batch)
metric.add_batch(predictions=predictions, references=labels)
eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}")
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(train_metrics, eval_metric, train_time, cur_step)
# save last checkpoint
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(args.output_dir, params=params)
if __name__ == "__main__":
main()