"docs/source/vscode:/vscode.git/clone" did not exist on "867f3950fa908632ddb3564873293b620d73c2dc"
Unverified commit d1f5ca1a, authored by Kamal Raj, committed by GitHub

[FLAX] glue training example refactor (#13815)

* refactor run_flax_glue.py

* updated readme

* rm unused import and args typo fix

* refactor

* make consistent arg name across task

* has_tensorboard check

* argparse -> argument dataclasses

* refactor according to review

* fix
parent db350394
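
The core of this change is swapping the hand-rolled `argparse` CLI for `HfArgumentParser` driving typed argument dataclasses, as in the other Transformers examples. A minimal sketch of that pattern is below; the `ExampleArguments` class and its fields are illustrative stand-ins, not code from this commit (the real script uses `ModelArguments`, `DataTrainingArguments`, and `TrainingArguments`).

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ExampleArguments:
    # Each field becomes a --flag; the metadata "help" string is shown by --help.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    max_seq_length: Optional[int] = field(
        default=128, metadata={"help": "Maximum total input sequence length after tokenization."}
    )


if __name__ == "__main__":
    # Flags such as --model_name_or_path bert-base-cased --max_seq_length 128 are parsed
    # into a typed dataclass instance instead of an argparse Namespace.
    parser = HfArgumentParser(ExampleArguments)
    (example_args,) = parser.parse_args_into_dataclasses()
    print(example_args.model_name_or_path, example_args.max_seq_length)
```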
@@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus):
             --per_device_train_batch_size=2
             --per_device_eval_batch_size=1
             --learning_rate=1e-4
-            --max_train_steps=10
-            --num_warmup_steps=2
+            --eval_steps=2
+            --warmup_steps=2
             --seed=42
-            --max_length=128
+            --max_seq_length=128
             """.split()
         with patch.object(sys, "argv", testargs):
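
The test above drives the example script through its command line by patching `sys.argv`. A minimal sketch of that pattern, outside the test harness (it assumes `run_flax_glue.py` is importable from the working directory and that the required dependencies and network access are available):

```python
import sys
from unittest.mock import patch

import run_flax_glue  # the example script; its main() is defined in the diff below

testargs = """
    run_flax_glue.py
    --model_name_or_path bert-base-cased
    --task_name mrpc
    --max_seq_length 128
    --eval_steps 2
    --warmup_steps 2
    --output_dir ./mrpc
    """.split()

# While the patch is active, HfArgumentParser inside main() reads these flags
# exactly as if the script had been launched from the shell.
with patch.object(sys, "argv", testargs):
    run_flax_glue.main()
```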
@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
 python run_flax_glue.py \
     --model_name_or_path bert-base-cased \
     --task_name ${TASK_NAME} \
-    --max_length 128 \
+    --max_seq_length 128 \
     --learning_rate 2e-5 \
     --num_train_epochs 3 \
     --per_device_train_batch_size 4 \
+    --eval_steps 100 \
     --output_dir ./$TASK_NAME/ \
     --push_to_hub
 ```
-where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
+where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
 Using the command above, the script will train for 3 epochs and run eval after each epoch.
 Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
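
If you want to inspect those event files programmatically rather than through the TensorBoard UI, something like the following works. This is a sketch: it assumes the `tensorboard` package is installed, that `./mrpc` was the `--output_dir`, and that an `eval_accuracy` scalar was actually logged (tag names depend on the task).

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point the accumulator at the --output_dir used above and load all events.
acc = EventAccumulator("./mrpc")
acc.Reload()

print(acc.Tags()["scalars"])          # e.g. ["train_time", "train_loss", "eval_accuracy", ...]
for event in acc.Scalars("eval_accuracy"):
    print(event.step, event.value)
```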
@@ -14,18 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
-import argparse
 import json
 import logging
 import os
 import random
+import sys
 import time
+from dataclasses import dataclass, field
 from itertools import chain
 from pathlib import Path
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple

 import datasets
+import numpy as np
 from datasets import load_dataset, load_metric
+from tqdm import tqdm

 import jax
 import jax.numpy as jnp
@@ -40,13 +43,18 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForSequenceClassification,
+    HfArgumentParser,
     PretrainedConfig,
+    TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name
+from transformers.utils import check_min_version

 logger = logging.getLogger(__name__)

+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.16.0.dev0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
@@ -66,101 +74,118 @@ task_to_keys = {
 }


-def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
-    parser.add_argument(
-        "--task_name",
-        type=str,
-        default=None,
-        help="The name of the glue task to train on.",
-        choices=list(task_to_keys.keys()),
-    )
-    parser.add_argument(
-        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
-    )
-    parser.add_argument(
-        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
-    )
-    parser.add_argument(
-        "--max_length",
-        type=int,
-        default=128,
-        help=(
-            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
-            " sequences shorter will be padded."
-        ),
-    )
-    parser.add_argument(
-        "--model_name_or_path",
-        type=str,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-        required=True,
-    )
-    parser.add_argument(
-        "--use_slow_tokenizer",
-        action="store_true",
-        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
-    )
-    parser.add_argument(
-        "--per_device_train_batch_size",
-        type=int,
-        default=8,
-        help="Batch size (per device) for the training dataloader.",
-    )
-    parser.add_argument(
-        "--per_device_eval_batch_size",
-        type=int,
-        default=8,
-        help="Batch size (per device) for the evaluation dataloader.",
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=5e-5,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
-    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
-    )
-    parser.add_argument(
-        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
-    parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
-    parser.add_argument(
-        "--push_to_hub",
-        action="store_true",
-        help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
-    )
-    parser.add_argument(
-        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
-    )
-    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
-    args = parser.parse_args()
-
-    # Sanity checks
-    if args.task_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a task name or a training/validation file.")
-    else:
-        if args.train_file is not None:
-            extension = args.train_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-        if args.validation_file is not None:
-            extension = args.validation_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
-
-    if args.push_to_hub:
-        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
-
-    if args.output_dir is not None:
-        os.makedirs(args.output_dir, exist_ok=True)
-
-    return args
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    use_slow_tokenizer: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."},
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    task_name: Optional[str] = field(
+        default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
+    )
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
+    )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
+    )
+    text_column_name: Optional[str] = field(
+        default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
+    )
+    label_column_name: Optional[str] = field(
+        default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    max_seq_length: int = field(
+        default=None,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+            "value if set."
+        },
+    )
+
+    def __post_init__(self):
+        if self.task_name is None and self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        self.task_name = self.task_name.lower() if type(self.task_name) == str else self.task_name


 def create_train_state(
@@ -249,7 +274,7 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
     for perm in perms:
         batch = dataset[perm]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
+        batch = {k: np.array(v) for k, v in batch.items()}
         batch = shard(batch)

         yield batch
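
The `shard` call above is what gives each batch its leading device dimension for `pmap`: every array goes from shape `(batch, ...)` to `(local_devices, batch // local_devices, ...)`. A small self-contained illustration (the shapes and field names are examples only):

```python
import jax
import numpy as np
from flax.training.common_utils import shard

n_devices = jax.local_device_count()
per_device_batch_size = 4

# A fake tokenized batch whose leading dimension is divisible by the device count.
batch = {
    "input_ids": np.zeros((n_devices * per_device_batch_size, 128), dtype=np.int32),
    "labels": np.zeros((n_devices * per_device_batch_size,), dtype=np.int32),
}

sharded = shard(batch)
print(sharded["input_ids"].shape)  # (n_devices, per_device_batch_size, 128)
```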
@@ -259,14 +284,20 @@ def glue_eval_data_collator(dataset: Dataset, batch_size: int):
     """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
     for i in range(len(dataset) // batch_size):
         batch = dataset[i * batch_size : (i + 1) * batch_size]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
+        batch = {k: np.array(v) for k, v in batch.items()}
         batch = shard(batch)

         yield batch


 def main():
-    args = parse_args()
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
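
As the new parsing block shows, the same dataclasses can also be filled from a single JSON file instead of CLI flags. A hedged, stand-alone sketch of that branch using a made-up dataclass and file name (the real script parses `ModelArguments`, `DataTrainingArguments`, and `TrainingArguments` the same way):

```python
import json
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TaskArguments:
    # Illustrative stand-in for DataTrainingArguments.
    task_name: Optional[str] = field(default=None)
    max_seq_length: int = field(default=128)


# Write a config that mirrors a subset of the CLI flags.
with open("glue_args.json", "w") as f:
    json.dump({"task_name": "mrpc", "max_seq_length": 128}, f)

parser = HfArgumentParser(TaskArguments)
(task_args,) = parser.parse_json_file(json_file="glue_args.json")
print(task_args.task_name, task_args.max_seq_length)  # mrpc 128
```

Passing a single `*.json` path to the script (`python run_flax_glue.py glue_args.json`) takes exactly this `parse_json_file` branch.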
@@ -284,12 +315,14 @@
         transformers.utils.logging.set_verbosity_error()

     # Handle the repository creation
-    if args.push_to_hub:
-        if args.hub_model_id is None:
-            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
         else:
-            repo_name = args.hub_model_id
-        repo = Repository(args.output_dir, clone_from=repo_name)
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)

     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
     # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
@@ -303,24 +336,24 @@
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
-    if args.task_name is not None:
+    if data_args.task_name is not None:
         # Downloading and loading a dataset from the hub.
-        raw_datasets = load_dataset("glue", args.task_name)
+        raw_datasets = load_dataset("glue", data_args.task_name)
     else:
         # Loading the dataset from local csv or json file.
         data_files = {}
-        if args.train_file is not None:
-            data_files["train"] = args.train_file
-        if args.validation_file is not None:
-            data_files["validation"] = args.validation_file
-        extension = (args.train_file if args.train_file is not None else args.valid_file).split(".")[-1]
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1]
         raw_datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

     # Labels
-    if args.task_name is not None:
-        is_regression = args.task_name == "stsb"
+    if data_args.task_name is not None:
+        is_regression = data_args.task_name == "stsb"
         if not is_regression:
             label_list = raw_datasets["train"].features["label"].names
             num_labels = len(label_list)
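
For reference, the GLUE branch above is a plain `datasets.load_dataset` call; a quick way to see the splits and label names it produces for one task (the printed values are indicative, taken from the MRPC config):

```python
from datasets import load_dataset

raw_datasets = load_dataset("glue", "mrpc")

print(raw_datasets)                                   # DatasetDict with train/validation/test splits
print(raw_datasets["train"].features["label"].names)  # ["not_equivalent", "equivalent"]
```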
@@ -339,13 +372,17 @@
             num_labels = len(label_list)

     # Load pretrained model and tokenizer
-    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
-    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path, num_labels=num_labels, finetuning_task=data_args.task_name
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, use_fast=not model_args.use_slow_tokenizer
+    )
+    model = FlaxAutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)

     # Preprocessing the datasets
-    if args.task_name is not None:
-        sentence1_key, sentence2_key = task_to_keys[args.task_name]
+    if data_args.task_name is not None:
+        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
     else:
         # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
         non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
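
A minimal stand-alone version of the loading step above, just to show the shapes involved (illustrative only; it assumes the checkpoint ships Flax weights, and the freshly initialized classification head will trigger the usual warning):

```python
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification

# MRPC is a 2-label task, so the classification head gets 2 outputs.
config = AutoConfig.from_pretrained("bert-base-cased", num_labels=2, finetuning_task="mrpc")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = FlaxAutoModelForSequenceClassification.from_pretrained("bert-base-cased", config=config)

inputs = tokenizer("A sentence.", "Another sentence.", return_tensors="np")
logits = model(**inputs).logits
print(logits.shape)  # (1, 2): one row of scores per label
```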
@@ -361,7 +398,7 @@
     label_to_id = None
     if (
         model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
-        and args.task_name is not None
+        and data_args.task_name is not None
         and not is_regression
     ):
         # Some have all caps in their config, some don't.
@@ -378,7 +415,7 @@
             f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
             "\nIgnoring the model labels as a result.",
         )
-    elif args.task_name is None:
+    elif data_args.task_name is None:
         label_to_id = {v: i for i, v in enumerate(label_list)}

     def preprocess_function(examples):
@@ -386,7 +423,7 @@
         texts = (
             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
-        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)
+        result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)

         if "label" in examples:
             if label_to_id is not None:
@@ -402,7 +439,7 @@
     )

     train_dataset = processed_datasets["train"]
-    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
+    eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]

     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
@@ -414,8 +451,8 @@
         try:
             from flax.metrics.tensorboard import SummaryWriter

-            summary_writer = SummaryWriter(args.output_dir)
-            summary_writer.hparams(vars(args))
+            summary_writer = SummaryWriter(training_args.output_dir)
+            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
@@ -427,7 +464,7 @@
             "Please run pip install tensorboard to enable."
         )

-    def write_metric(train_metrics, eval_metrics, train_time, step):
+    def write_train_metric(summary_writer, train_metrics, train_time, step):
         summary_writer.scalar("train_time", train_time, step)

         train_metrics = get_metrics(train_metrics)
@@ -436,22 +473,27 @@
             for i, val in enumerate(vals):
                 summary_writer.scalar(tag, val, step - len(vals) + i + 1)

+    def write_eval_metric(summary_writer, eval_metrics, step):
         for metric_name, value in eval_metrics.items():
             summary_writer.scalar(f"eval_{metric_name}", value, step)

-    num_epochs = int(args.num_train_epochs)
-    rng = jax.random.PRNGKey(args.seed)
+    num_epochs = int(training_args.num_train_epochs)
+    rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
-    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
+    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
+    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

     learning_rate_fn = create_learning_rate_fn(
-        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
+        len(train_dataset),
+        train_batch_size,
+        training_args.num_train_epochs,
+        training_args.warmup_steps,
+        training_args.learning_rate,
     )

     state = create_train_state(
-        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
+        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay
     )

     # define step functions
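
`create_learning_rate_fn` itself is outside the shown hunks; in these Flax examples it typically builds linear warmup followed by linear decay. A rough optax equivalent, to make the interaction with `warmup_steps` concrete (the function name, variable names, and numbers below are illustrative, not the script's exact code):

```python
import optax


def make_linear_warmup_decay_schedule(num_train_examples, train_batch_size, num_train_epochs, warmup_steps, learning_rate):
    # Total number of optimizer steps over the whole run.
    steps_per_epoch = num_train_examples // train_batch_size
    total_steps = steps_per_epoch * num_train_epochs
    # Ramp from 0 to the peak learning rate, then decay back to 0.
    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0.0, transition_steps=total_steps - warmup_steps)
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[warmup_steps])


schedule = make_linear_warmup_decay_schedule(10000, 32, 3, 100, 2e-5)
print(schedule(0), schedule(100), schedule(500))  # 0 at step 0, the peak at the warmup boundary, then decaying
```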
@@ -482,8 +524,8 @@
     p_eval_step = jax.pmap(eval_step, axis_name="batch")

-    if args.task_name is not None:
-        metric = load_metric("glue", args.task_name)
+    if data_args.task_name is not None:
+        metric = load_metric("glue", data_args.task_name)
     else:
         metric = load_metric("accuracy")
@@ -493,25 +535,56 @@
     # make sure weights are replicated on each device
     state = replicate(state)

-    for epoch in range(1, num_epochs + 1):
-        logger.info(f"Epoch {epoch}")
-        logger.info(" Training...")
+    steps_per_epoch = len(train_dataset) // train_batch_size
+    total_steps = steps_per_epoch * num_epochs
+    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
+    for epoch in epochs:
         train_start = time.time()
         train_metrics = []
+
+        # Create sampling rng
         rng, input_rng = jax.random.split(rng)

         # train
-        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
-            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
-            train_metrics.append(metrics)
-        train_time += time.time() - train_start
-        logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
-
-        logger.info(" Evaluating...")
-
-        # evaluate
-        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
-            labels = batch.pop("labels")
-            predictions = p_eval_step(state, batch)
-            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
+        train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
+        for step, batch in enumerate(
+            tqdm(
+                train_loader,
+                total=steps_per_epoch,
+                desc="Training...",
+                position=1,
+            ),
+        ):
+            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
+            train_metrics.append(train_metric)
+
+            cur_step = (epoch * steps_per_epoch) + (step + 1)
+
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []
+
+            if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
+                eval_metrics = {}
+                # evaluate
+                eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
+                for batch in tqdm(
+                    eval_loader,
+                    total=len(eval_dataset) // eval_batch_size,
+                    desc="Evaluating ...",
+                    position=2,
+                ):
+                    labels = batch.pop("labels")
+                    predictions = p_eval_step(state, batch)
+                    metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
@@ -523,33 +596,33 @@
                 if num_leftover_samples > 0 and jax.process_index() == 0:
                     # take leftover samples
                     batch = eval_dataset[-num_leftover_samples:]
-                    batch = {k: jnp.array(v) for k, v in batch.items()}
+                    batch = {k: np.array(v) for k, v in batch.items()}

                     labels = batch.pop("labels")
                     predictions = eval_step(unreplicate(state), batch)
                     metric.add_batch(predictions=predictions, references=labels)

                 eval_metric = metric.compute()

-        logger.info(f" Done! Eval metrics: {eval_metric}")
-
-        cur_step = epoch * (len(train_dataset) // train_batch_size)
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            write_metric(train_metrics, eval_metric, train_time, cur_step)
-
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(args.output_dir, params=params)
-            tokenizer.save_pretrained(args.output_dir)
-            if args.push_to_hub:
-                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
+                logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")
+
+                if has_tensorboard and jax.process_index() == 0:
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
+                # save checkpoint after each epoch and push checkpoint to the hub
+                if jax.process_index() == 0:
+                    params = jax.device_get(unreplicate(state.params))
+                    model.save_pretrained(training_args.output_dir, params=params)
+                    tokenizer.save_pretrained(training_args.output_dir)
+                    if training_args.push_to_hub:
+                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

     # save the eval metrics in json
     if jax.process_index() == 0:
         eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
-        path = os.path.join(args.output_dir, "eval_results.json")
+        path = os.path.join(training_args.output_dir, "eval_results.json")
         with open(path, "w") as f:
             json.dump(eval_metric, f, indent=4, sort_keys=True)