Unverified Commit d1f5ca1a authored by Kamal Raj, committed by GitHub

[FLAX] glue training example refactor (#13815)

* refactor run_flax_glue.py

* updated readme

* rm unused import and args typo fix

* refactor

* make consistent arg name across task

* has_tensorboard check

* argparse -> argument dataclasses

* refactor according to review

* fix
parent db350394
@@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus):
             --per_device_train_batch_size=2
             --per_device_eval_batch_size=1
             --learning_rate=1e-4
-            --max_train_steps=10
-            --num_warmup_steps=2
+            --eval_steps=2
+            --warmup_steps=2
             --seed=42
-            --max_length=128
+            --max_seq_length=128
             """.split()
         with patch.object(sys, "argv", testargs):
@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
 python run_flax_glue.py \
     --model_name_or_path bert-base-cased \
     --task_name ${TASK_NAME} \
-    --max_length 128 \
+    --max_seq_length 128 \
     --learning_rate 2e-5 \
     --num_train_epochs 3 \
     --per_device_train_batch_size 4 \
+    --eval_steps 100 \
     --output_dir ./$TASK_NAME/ \
     --push_to_hub
 ```
 
-where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
+where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
 
 Using the command above, the script will train for 3 epochs and run eval after each epoch.
 Metrics and hyperparameters are stored in TensorFlow event files in `--output_dir`.
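For reference, those event files can be inspected with TensorBoard (e.g. `tensorboard --logdir ./$TASK_NAME/`), assuming the `tensorboard` package is installed.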
@@ -14,18 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
-import argparse
 import json
 import logging
 import os
 import random
+import sys
 import time
+from dataclasses import dataclass, field
 from itertools import chain
 from pathlib import Path
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import datasets
+import numpy as np
 from datasets import load_dataset, load_metric
+from tqdm import tqdm
 
 import jax
 import jax.numpy as jnp
@@ -40,13 +43,18 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForSequenceClassification,
+    HfArgumentParser,
     PretrainedConfig,
+    TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name
+from transformers.utils import check_min_version
 
 logger = logging.getLogger(__name__)
 
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.16.0.dev0")
 
 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
@@ -66,101 +74,118 @@ task_to_keys = {
 }
 
 
-def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
-    parser.add_argument(
-        "--task_name",
-        type=str,
-        default=None,
-        help="The name of the glue task to train on.",
-        choices=list(task_to_keys.keys()),
-    )
-    parser.add_argument(
-        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
-    )
-    parser.add_argument(
-        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
-    )
-    parser.add_argument(
-        "--max_length",
-        type=int,
-        default=128,
-        help=(
-            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
-            " sequences shorter will be padded."
-        ),
-    )
-    parser.add_argument(
-        "--model_name_or_path",
-        type=str,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-        required=True,
-    )
-    parser.add_argument(
-        "--use_slow_tokenizer",
-        action="store_true",
-        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
-    )
-    parser.add_argument(
-        "--per_device_train_batch_size",
-        type=int,
-        default=8,
-        help="Batch size (per device) for the training dataloader.",
-    )
-    parser.add_argument(
-        "--per_device_eval_batch_size",
-        type=int,
-        default=8,
-        help="Batch size (per device) for the evaluation dataloader.",
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=5e-5,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
-    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
-    )
-    parser.add_argument(
-        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
-    parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
-    parser.add_argument(
-        "--push_to_hub",
-        action="store_true",
-        help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
-    )
-    parser.add_argument(
-        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
-    )
-    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
-    args = parser.parse_args()
-
-    # Sanity checks
-    if args.task_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a task name or a training/validation file.")
-    else:
-        if args.train_file is not None:
-            extension = args.train_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
-        if args.validation_file is not None:
-            extension = args.validation_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
-    if args.push_to_hub:
-        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
-
-    if args.output_dir is not None:
-        os.makedirs(args.output_dir, exist_ok=True)
-
-    return args
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    use_slow_tokenizer: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."},
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    task_name: Optional[str] = field(
+        default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
+    )
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
+    )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
+    )
+    text_column_name: Optional[str] = field(
+        default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
+    )
+    label_column_name: Optional[str] = field(
+        default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    max_seq_length: int = field(
+        default=None,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+            "value if set."
+        },
+    )
+
+    def __post_init__(self):
+        if self.task_name is None and self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+        self.task_name = self.task_name.lower() if type(self.task_name) == str else self.task_name
+
def create_train_state(
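For readers unfamiliar with `HfArgumentParser`, here is a minimal sketch (not part of the diff) of how the dataclasses above become a command line; the `ExampleArguments` class is a hypothetical stand-in:

```python
# Minimal sketch, assuming only that transformers is installed.
# Each dataclass field becomes a --flag, metadata["help"] becomes its help text,
# and a field without a default (like model_name_or_path above) is required.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ExampleArguments:  # hypothetical stand-in for ModelArguments
    model_name_or_path: str = field(metadata={"help": "Required flag."})
    task_name: Optional[str] = field(default=None, metadata={"help": "Optional flag."})


parser = HfArgumentParser(ExampleArguments)
# Same entry point the refactored main() uses:
(args,) = parser.parse_args_into_dataclasses(["--model_name_or_path", "bert-base-cased"])
assert args.model_name_or_path == "bert-base-cased" and args.task_name is None
# parser.parse_json_file("args.json") would instead read the same keys from a JSON file,
# which is the single-JSON-argument path in main() below.
```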
@@ -249,7 +274,7 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
 
     for perm in perms:
         batch = dataset[perm]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
+        batch = {k: np.array(v) for k, v in batch.items()}
         batch = shard(batch)
 
         yield batch
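The `jnp.array` → `np.array` switch keeps batch assembly on the host; the single device transfer then happens in `shard`, which splits the leading batch axis across local devices. A minimal sketch of that behavior, assuming `flax` is installed:

```python
import jax
import numpy as np
from flax.training.common_utils import shard

n = jax.local_device_count()
# Host-side batch built with NumPy, as in the collators above.
batch = {"input_ids": np.zeros((8 * n, 128), dtype=np.int32)}
sharded = shard(batch)
# shard reshapes (8 * n, 128) -> (n, 8, 128): one leading slice per device.
assert sharded["input_ids"].shape == (n, 8, 128)
```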
@@ -259,14 +284,20 @@ def glue_eval_data_collator(dataset: Dataset, batch_size: int):
     """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
     for i in range(len(dataset) // batch_size):
         batch = dataset[i * batch_size : (i + 1) * batch_size]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
+        batch = {k: np.array(v) for k, v in batch.items()}
         batch = shard(batch)
 
         yield batch
 
 
 def main():
-    args = parse_args()
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -284,12 +315,14 @@
         transformers.utils.logging.set_verbosity_error()
 
     # Handle the repository creation
-    if args.push_to_hub:
-        if args.hub_model_id is None:
-            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
+    if training_args.push_to_hub:
+        if training_args.hub_model_id is None:
+            repo_name = get_full_repo_name(
+                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+            )
         else:
-            repo_name = args.hub_model_id
-        repo = Repository(args.output_dir, clone_from=repo_name)
+            repo_name = training_args.hub_model_id
+        repo = Repository(training_args.output_dir, clone_from=repo_name)
 
     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
     # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
@@ -303,24 +336,24 @@
     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
     # download the dataset.
-    if args.task_name is not None:
+    if data_args.task_name is not None:
         # Downloading and loading a dataset from the hub.
-        raw_datasets = load_dataset("glue", args.task_name)
+        raw_datasets = load_dataset("glue", data_args.task_name)
     else:
         # Loading the dataset from local csv or json file.
         data_files = {}
-        if args.train_file is not None:
-            data_files["train"] = args.train_file
-        if args.validation_file is not None:
-            data_files["validation"] = args.validation_file
-        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = (data_args.train_file if data_args.train_file is not None else data_args.validation_file).split(".")[-1]
         raw_datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Labels
-    if args.task_name is not None:
-        is_regression = args.task_name == "stsb"
+    if data_args.task_name is not None:
+        is_regression = data_args.task_name == "stsb"
         if not is_regression:
             label_list = raw_datasets["train"].features["label"].names
             num_labels = len(label_list)
@@ -339,13 +372,17 @@
             num_labels = len(label_list)
 
     # Load pretrained model and tokenizer
-    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
-    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path, num_labels=num_labels, finetuning_task=data_args.task_name
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, use_fast=not model_args.use_slow_tokenizer
+    )
+    model = FlaxAutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)
 
     # Preprocessing the datasets
-    if args.task_name is not None:
-        sentence1_key, sentence2_key = task_to_keys[args.task_name]
+    if data_args.task_name is not None:
+        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
     else:
         # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
         non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
@@ -361,7 +398,7 @@
     label_to_id = None
     if (
         model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
-        and args.task_name is not None
+        and data_args.task_name is not None
         and not is_regression
     ):
         # Some have all caps in their config, some don't.
@@ -378,7 +415,7 @@
                 f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                 "\nIgnoring the model labels as a result.",
             )
-    elif args.task_name is None:
+    elif data_args.task_name is None:
         label_to_id = {v: i for i, v in enumerate(label_list)}
 
     def preprocess_function(examples):
@@ -386,7 +423,7 @@
         texts = (
             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
-        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)
+        result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)
 
         if "label" in examples:
             if label_to_id is not None:
@@ -402,7 +439,7 @@
     )
 
     train_dataset = processed_datasets["train"]
-    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
+    eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
 
     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 3):
@@ -414,8 +451,8 @@
         try:
             from flax.metrics.tensorboard import SummaryWriter
 
-            summary_writer = SummaryWriter(args.output_dir)
-            summary_writer.hparams(vars(args))
+            summary_writer = SummaryWriter(training_args.output_dir)
+            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
         except ImportError as ie:
             has_tensorboard = False
             logger.warning(
@@ -427,7 +464,7 @@
                 "Please run pip install tensorboard to enable."
             )
 
-    def write_metric(train_metrics, eval_metrics, train_time, step):
+    def write_train_metric(summary_writer, train_metrics, train_time, step):
         summary_writer.scalar("train_time", train_time, step)
 
         train_metrics = get_metrics(train_metrics)
@@ -436,22 +473,27 @@
             for i, val in enumerate(vals):
                 summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+    def write_eval_metric(summary_writer, eval_metrics, step):
         for metric_name, value in eval_metrics.items():
             summary_writer.scalar(f"eval_{metric_name}", value, step)
 
-    num_epochs = int(args.num_train_epochs)
-    rng = jax.random.PRNGKey(args.seed)
+    num_epochs = int(training_args.num_train_epochs)
+    rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())
 
-    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
-    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
+    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
+    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
 
     learning_rate_fn = create_learning_rate_fn(
-        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
+        len(train_dataset),
+        train_batch_size,
+        training_args.num_train_epochs,
+        training_args.warmup_steps,
+        training_args.learning_rate,
     )
 
     state = create_train_state(
-        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
+        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay
     )
 
     # define step functions
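`create_learning_rate_fn` is defined earlier in the file and unchanged by this diff; based on the standard Flax example helpers, it is presumably a linear warmup followed by a linear decay, along these lines:

```python
import optax


def create_learning_rate_fn(train_ds_size, train_batch_size, num_train_epochs, num_warmup_steps, learning_rate):
    """Sketch of the helper: linear warmup to `learning_rate`, then linear decay to 0."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    # One callable mapping step -> learning rate, switching schedules at num_warmup_steps.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
```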
@@ -482,8 +524,8 @@
     p_eval_step = jax.pmap(eval_step, axis_name="batch")
 
-    if args.task_name is not None:
-        metric = load_metric("glue", args.task_name)
+    if data_args.task_name is not None:
+        metric = load_metric("glue", data_args.task_name)
     else:
         metric = load_metric("accuracy")
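For context, the `datasets` metric object used here accumulates predictions across batches and computes the task metric at the end; a minimal sketch:

```python
from datasets import load_metric

metric = load_metric("glue", "mrpc")  # MRPC reports accuracy and F1
metric.add_batch(predictions=[1, 0, 1], references=[1, 0, 0])
print(metric.compute())  # e.g. {'accuracy': 0.666..., 'f1': 0.666...}
```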
@@ -493,25 +535,56 @@
     # make sure weights are replicated on each device
     state = replicate(state)
 
     train_time = 0
-    for epoch in range(1, num_epochs + 1):
-        logger.info(f"Epoch {epoch}")
-        logger.info("  Training...")
-
-        train_start = time.time()
-        train_metrics = []
-        rng, input_rng = jax.random.split(rng)
-
-        # train
-        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
-            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
-            train_metrics.append(metrics)
-        train_time += time.time() - train_start
-        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")
-
-        logger.info("  Evaluating...")
-
-        # evaluate
-        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
-            labels = batch.pop("labels")
-            predictions = p_eval_step(state, batch)
-            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
+    steps_per_epoch = len(train_dataset) // train_batch_size
+    total_steps = steps_per_epoch * num_epochs
+    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
+    for epoch in epochs:
+        train_start = time.time()
+        train_metrics = []
+
+        # Create sampling rng
+        rng, input_rng = jax.random.split(rng)
+
+        # train
+        train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
+        for step, batch in enumerate(
+            tqdm(
+                train_loader,
+                total=steps_per_epoch,
+                desc="Training...",
+                position=1,
+            ),
+        ):
+            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
+            train_metrics.append(train_metric)
+
+            cur_step = (epoch * steps_per_epoch) + (step + 1)
+
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []
+
+            if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
+                # evaluate
+                eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
+                for batch in tqdm(
+                    eval_loader,
+                    total=len(eval_dataset) // eval_batch_size,
+                    desc="Evaluating ...",
+                    position=2,
+                ):
+                    labels = batch.pop("labels")
+                    predictions = p_eval_step(state, batch)
+                    metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
@@ -523,33 +596,33 @@
                 if num_leftover_samples > 0 and jax.process_index() == 0:
                     # take leftover samples
                     batch = eval_dataset[-num_leftover_samples:]
-                    batch = {k: jnp.array(v) for k, v in batch.items()}
+                    batch = {k: np.array(v) for k, v in batch.items()}
 
                     labels = batch.pop("labels")
                     predictions = eval_step(unreplicate(state), batch)
                     metric.add_batch(predictions=predictions, references=labels)
 
                 eval_metric = metric.compute()
-        logger.info(f"    Done! Eval metrics: {eval_metric}")
-
-        cur_step = epoch * (len(train_dataset) // train_batch_size)
+
+                logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")
 
                 # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            write_metric(train_metrics, eval_metric, train_time, cur_step)
+                if has_tensorboard and jax.process_index() == 0:
+                    write_eval_metric(summary_writer, eval_metric, cur_step)
 
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(args.output_dir, params=params)
-            tokenizer.save_pretrained(args.output_dir)
-            if args.push_to_hub:
-                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
+            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
+                # save checkpoint after each epoch and push checkpoint to the hub
+                if jax.process_index() == 0:
+                    params = jax.device_get(unreplicate(state.params))
+                    model.save_pretrained(training_args.output_dir, params=params)
+                    tokenizer.save_pretrained(training_args.output_dir)
+                    if training_args.push_to_hub:
+                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
 
     # save the eval metrics in json
     if jax.process_index() == 0:
         eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
-        path = os.path.join(args.output_dir, "eval_results.json")
+        path = os.path.join(training_args.output_dir, "eval_results.json")
         with open(path, "w") as f:
             json.dump(eval_metric, f, indent=4, sort_keys=True)
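The save path above works because `state` was replicated across devices earlier (`state = replicate(state)`); `unreplicate` takes one copy back, and `jax.device_get` pulls it to host memory before `save_pretrained` writes it, replacing the old `jax.tree_map(lambda x: x[0], ...)` idiom. A small sketch, assuming `flax` is installed:

```python
import jax
import numpy as np
from flax.jax_utils import replicate, unreplicate

params = {"w": np.ones((2, 2), dtype=np.float32)}
replicated = replicate(params)  # adds a leading device axis
host_params = jax.device_get(unreplicate(replicated))  # first copy, back on host
assert host_params["w"].shape == (2, 2)
```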