"...composable_kernel_rocm.git" did not exist on "e8c19535f7735acd72917ff2af6520ba080bf980"
Unverified commit 7490a97c, authored by Sanchit Gandhi, committed by GitHub

[Flax] Fix incomplete batches in example scripts (#17863)

* [Flax] Fix incomplete batches in example scripts

* fix dataloader batching

* convert jnp batch idxs to np array

* add missing `pad_shard_unpad` to final prediction generate step

* only `pad_shard_unpad` at inference time

* merge conflicts

* remove incomplete batch step from eval

* fix run_qa.py

* add `pad_shard_unpad` to run_flax_ner.py

* add `pad_shard_unpad` to run_flax_glue.py

* add `pad_shard_unpad` to run_image_classification.py

* make style

* fix mlm flax eval batches

* remove redundant imports
parent 9caf68a6
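The commit messages above all describe one pattern: evaluation loaders stop dropping the final, incomplete batch, and the eval/predict steps are wrapped in flax.jax_utils.pad_shard_unpad so the odd-sized batch can still be fed to a pmapped function. Below is a minimal, self-contained sketch of that pattern with a toy eval step and toy data; the names eval_step, params and the tiny batches are illustrative, not taken from the scripts.

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import pad_shard_unpad, replicate
from flax.training.common_utils import get_metrics

def eval_step(params, batch):
    # toy "model": squared error of a single scalar weight
    loss = jnp.mean((batch["x"] * params["w"] - batch["y"]) ** 2)
    return {"loss": jax.lax.pmean(loss, axis_name="batch")}

p_eval_step = jax.pmap(eval_step, axis_name="batch")
params = replicate({"w": jnp.float32(2.0)})

# The second batch is deliberately short (7 examples), as produced by a
# drop_last=False loader; note the batches are host-side NumPy arrays.
eval_batches = [
    {"x": np.ones(8, np.float32), "y": np.ones(8, np.float32)},
    {"x": np.ones(7, np.float32), "y": np.ones(7, np.float32)},
]

eval_metrics = []
for batch in eval_batches:
    # pad_shard_unpad pads the batch to a size the local devices can split evenly,
    # shards it, runs the pmapped step, and (because of static_return=True) hands
    # the replicated metrics back untouched.
    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
        params, batch, min_device_batch=4  # stands in for per_device_eval_batch_size
    )
    eval_metrics.append(metrics)

eval_metrics = get_metrics(eval_metrics)  # stack the per-step metric trees
print(jax.tree_util.tree_map(np.mean, eval_metrics))

In this sketch the zero-padded rows still enter the mean loss; static_return=True only tells the wrapper not to try to un-pad the replicated, scalar metrics.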
@@ -43,7 +43,7 @@ import jax.numpy as jnp
 import optax
 import transformers
 from flax import jax_utils, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
@@ -264,20 +264,24 @@ class TrainState(train_state.TrainState):
         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
     """
-    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-    Shuffle batches if `shuffle` is `True`.
+    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
+    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
     """
-    steps_per_epoch = len(dataset) // batch_size
-
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
+        batch_idx = np.asarray(batch_idx)
     else:
-        batch_idx = jnp.arange(len(dataset))
+        batch_idx = np.arange(len(dataset))

-    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+    if drop_last:
+        steps_per_epoch = len(dataset) // batch_size
+        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+    else:
+        steps_per_epoch = math.ceil(len(dataset) / batch_size)
+        batch_idx = np.array_split(batch_idx, steps_per_epoch)

     for idx in batch_idx:
         batch = dataset[idx]
@@ -621,7 +625,8 @@ def main():
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()
     steps_per_epoch = len(train_dataset) // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs
@@ -764,13 +769,14 @@ def main():
             if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                 # ======================== Evaluating ==============================
                 eval_metrics = []
-                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-                eval_steps = len(eval_dataset) // eval_batch_size
+                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+                eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
                 for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                     # Model forward
                     batch = next(eval_loader)
-                    batch = shard(batch)
-                    metrics = p_eval_step(state.params, batch)
+                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                        state.params, batch, min_device_batch=per_device_eval_batch_size
+                    )
                     eval_metrics.append(metrics)

                 # normalize eval metrics
@@ -806,12 +812,14 @@ def main():
     # Eval after training
     if training_args.do_eval:
         eval_metrics = []
-        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-        eval_steps = len(eval_dataset) // eval_batch_size
+        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
         for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
             # Model forward
-            batch = shard(next(eval_loader))
-            metrics = p_eval_step(state.params, batch)
+            batch = next(eval_loader)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                state.params, batch, min_device_batch=per_device_eval_batch_size
+            )
             eval_metrics.append(metrics)

         # normalize eval metrics
...
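For concreteness, here is what the two branches of the reworked data_loader do to the batch indices, using plain NumPy and hypothetical sizes (10 examples, batch size 4). Note that np.array_split makes the drop_last=False batches nearly equal in size rather than keeping every batch but the last one full.

import math
import numpy as np

dataset_len, batch_size = 10, 4
batch_idx = np.arange(dataset_len)

# drop_last=True: two full batches, the last two examples are skipped
steps = dataset_len // batch_size
full = batch_idx[: steps * batch_size].reshape((steps, batch_size))
print(full.tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# drop_last=False: ceil(10 / 4) = 3 batches, nothing is skipped, but the
# batches come out as sizes 4, 3, 3 rather than 4, 4, 2
steps = math.ceil(dataset_len / batch_size)
ragged = np.array_split(batch_idx, steps)
print([b.tolist() for b in ragged])  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]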
@@ -43,6 +43,7 @@ import jax
 import jax.numpy as jnp
 import optax
 from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -326,15 +327,20 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels


-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+    """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
     num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-    batch_idx = np.split(samples_idx, sections_split)
-    return batch_idx
+    if drop_last:
+        samples_to_remove = num_samples % batch_size
+        if samples_to_remove != 0:
+            samples_idx = samples_idx[:-samples_to_remove]
+        sections_split = num_samples // batch_size
+        samples_idx = samples_idx.reshape((sections_split, batch_size))
+    else:
+        sections_split = math.ceil(num_samples / batch_size)
+        samples_idx = np.array_split(samples_idx, sections_split)
+    return samples_idx


 def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -632,12 +638,14 @@ def main():
             config,
             seed=training_args.seed,
             dtype=getattr(jnp, model_args.dtype),
+            use_auth_token=True if model_args.use_auth_token else None,
         )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()

     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
@@ -796,7 +804,7 @@ def main():
                 num_eval_samples = len(tokenized_datasets["validation"])
                 # Avoid using jax.numpy here in case of TPU training
                 eval_samples_idx = np.arange(num_eval_samples)
-                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

                 eval_metrics = []
                 for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -804,8 +812,9 @@ def main():
                     model_inputs = data_collator(samples, pad_to_multiple_of=16)

                     # Model forward
-                    model_inputs = shard(model_inputs.data)
-                    metrics = p_eval_step(state.params, model_inputs)
+                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                        state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+                    )
                     eval_metrics.append(metrics)

                 # normalize eval metrics
@@ -835,7 +844,7 @@ def main():
         num_eval_samples = len(tokenized_datasets["validation"])
         # Avoid using jax.numpy here in case of TPU training
         eval_samples_idx = np.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

         eval_metrics = []
         for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -843,8 +852,9 @@ def main():
             model_inputs = data_collator(samples, pad_to_multiple_of=16)

             # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+            )
             eval_metrics.append(metrics)

         # normalize eval metrics
...
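The same idea appears above as generate_batch_splits. A small check, reproducing the new helper from the hunk and calling it with toy sizes, highlights that the return type changes with drop_last: a 2-D index array when incomplete batches are dropped, a list of index arrays otherwise. The eval loops simply iterate over either.

import math
import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
    """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
    num_samples = len(samples_idx)
    if drop_last:
        samples_to_remove = num_samples % batch_size
        if samples_to_remove != 0:
            samples_idx = samples_idx[:-samples_to_remove]
        sections_split = num_samples // batch_size
        samples_idx = samples_idx.reshape((sections_split, batch_size))
    else:
        sections_split = math.ceil(num_samples / batch_size)
        samples_idx = np.array_split(samples_idx, sections_split)
    return samples_idx

idx = np.arange(10)
print(generate_batch_splits(idx, 4).shape)  # (2, 4) -- ndarray, the tail is dropped
print([b.tolist() for b in generate_batch_splits(idx, 4, drop_last=False)])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]] -- list of arrays, nothing is dropped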
@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=t5
 """
 import json
 import logging
+import math
 import os
 import sys
 import time
@@ -41,6 +42,7 @@ import jax
 import jax.numpy as jnp
 import optax
 from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -326,6 +328,7 @@ class FlaxDataCollatorForT5MLM:
     decoder_start_token_id: int

     def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
         # convert list to dict and tensorize input
         batch = BatchEncoding(
             {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
@@ -394,6 +397,7 @@ class FlaxDataCollatorForT5MLM:
         return input_ids

     def random_spans_noise_mask(self, length):
         """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

         Noise mask consisting of random spans of noise tokens.
@@ -457,15 +461,20 @@ class FlaxDataCollatorForT5MLM:
         return is_noise[:orig_length]


-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+    """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
     num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-    batch_idx = np.split(samples_idx, sections_split)
-    return batch_idx
+    if drop_last:
+        samples_to_remove = num_samples % batch_size
+        if samples_to_remove != 0:
+            samples_idx = samples_idx[:-samples_to_remove]
+        sections_split = num_samples // batch_size
+        samples_idx = samples_idx.reshape((sections_split, batch_size))
+    else:
+        sections_split = math.ceil(num_samples / batch_size)
+        samples_idx = np.array_split(samples_idx, sections_split)
+    return samples_idx


 def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -737,6 +746,7 @@ def main():
             config,
             seed=training_args.seed,
             dtype=getattr(jnp, model_args.dtype),
+            use_auth_token=True if model_args.use_auth_token else None,
         )

     # Data collator
@@ -754,7 +764,8 @@ def main():
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()

     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
@@ -915,7 +926,7 @@ def main():
                 num_eval_samples = len(tokenized_datasets["validation"])
                 # Avoid using jax.numpy here in case of TPU training
                 eval_samples_idx = np.arange(num_eval_samples)
-                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

                 eval_metrics = []
                 for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -923,8 +934,9 @@ def main():
                     model_inputs = data_collator(samples)

                     # Model forward
-                    model_inputs = shard(model_inputs.data)
-                    metrics = p_eval_step(state.params, model_inputs)
+                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                        state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+                    )
                     eval_metrics.append(metrics)

                 # get eval metrics
@@ -952,7 +964,7 @@ def main():
         num_eval_samples = len(tokenized_datasets["validation"])
         # Avoid using jax.numpy here in case of TPU training
         eval_samples_idx = np.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

         eval_metrics = []
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -960,8 +972,9 @@ def main():
             model_inputs = data_collator(samples)

             # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+            )
             eval_metrics.append(metrics)

         # get eval metrics
...
@@ -20,13 +20,13 @@ Fine-tuning the library models for question answering.
 import json
 import logging
+import math
 import os
 import random
 import sys
 import time
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from itertools import chain
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple
@@ -40,7 +40,7 @@ import jax.numpy as jnp
 import optax
 import transformers
 from flax import struct, traverse_util
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -406,11 +406,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
 # region eval data iterator
 def eval_data_collator(dataset: Dataset, batch_size: int):
-    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
-    for i in range(len(dataset) // batch_size):
-        batch = dataset[i * batch_size : (i + 1) * batch_size]
+    """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
+    batch_idx = np.arange(len(dataset))
+
+    steps_per_epoch = math.ceil(len(dataset) / batch_size)
+    batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+    for idx in batch_idx:
+        batch = dataset[idx]
         batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)

         yield batch
@@ -856,8 +860,9 @@ def main():
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
-    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.local_device_count()
     # endregion

     # region Load model
@@ -975,32 +980,17 @@ def main():
             # evaluate
             for batch in tqdm(
                 eval_data_collator(eval_dataset, eval_batch_size),
-                total=len(eval_dataset) // eval_batch_size,
+                total=math.ceil(len(eval_dataset) / eval_batch_size),
                 desc="Evaluating ...",
                 position=2,
             ):
                 _ = batch.pop("example_id")
                 _ = batch.pop("offset_mapping")
-                predictions = p_eval_step(state, batch)
-                start_logits = np.array([pred for pred in chain(*predictions[0])])
-                end_logits = np.array([pred for pred in chain(*predictions[1])])
-                all_start_logits.append(start_logits)
-                all_end_logits.append(end_logits)
-
-            # evaluate also on leftover examples (not divisible by batch_size)
-            num_leftover_samples = len(eval_dataset) % eval_batch_size
-
-            # make sure leftover batch is evaluated on one device
-            if num_leftover_samples > 0 and jax.process_index() == 0:
-                # take leftover samples
-                batch = eval_dataset[-num_leftover_samples:]
-                batch = {k: np.array(v) for k, v in batch.items()}
-                _ = batch.pop("example_id")
-                _ = batch.pop("offset_mapping")
-
-                predictions = eval_step(unreplicate(state), batch)
-                start_logits = np.array([pred for pred in predictions[0]])
-                end_logits = np.array([pred for pred in predictions[1]])
+                predictions = pad_shard_unpad(p_eval_step)(
+                    state, batch, min_device_batch=per_device_eval_batch_size
+                )
+                start_logits = np.array(predictions[0])
+                end_logits = np.array(predictions[1])
                 all_start_logits.append(start_logits)
                 all_end_logits.append(end_logits)
@@ -1039,30 +1029,15 @@ def main():
         all_start_logits = []
         all_end_logits = []

-        eva_loader = eval_data_collator(eval_dataset, eval_batch_size)
-        for batch in tqdm(eva_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
-            _ = batch.pop("example_id")
-            _ = batch.pop("offset_mapping")
-            predictions = p_eval_step(state, batch)
-            start_logits = np.array([pred for pred in chain(*predictions[0])])
-            end_logits = np.array([pred for pred in chain(*predictions[1])])
-            all_start_logits.append(start_logits)
-            all_end_logits.append(end_logits)
-
-        # evaluate also on leftover examples (not divisible by batch_size)
-        num_leftover_samples = len(eval_dataset) % eval_batch_size
-
-        # make sure leftover batch is evaluated on one device
-        if num_leftover_samples > 0 and jax.process_index() == 0:
-            # take leftover samples
-            batch = eval_dataset[-num_leftover_samples:]
-            batch = {k: np.array(v) for k, v in batch.items()}
+        eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
+        for batch in tqdm(
+            eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2
+        ):
             _ = batch.pop("example_id")
             _ = batch.pop("offset_mapping")
-
-            predictions = eval_step(unreplicate(state), batch)
-            start_logits = np.array([pred for pred in predictions[0]])
-            end_logits = np.array([pred for pred in predictions[1]])
+            predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
+            start_logits = np.array(predictions[0])
+            end_logits = np.array(predictions[1])
             all_start_logits.append(start_logits)
             all_end_logits.append(end_logits)
...
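In run_qa.py (and the GLUE/NER scripts below) the eval step returns per-example outputs rather than pooled metrics, so pad_shard_unpad is used with its default static_return=False: the wrapper gathers the per-device outputs and slices off the padding, which is why the separate leftover-samples branch could be deleted. A toy check of that behaviour, with a hypothetical model and shapes and any number of local devices:

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import pad_shard_unpad

def eval_step(params, batch):
    return batch["inputs"] @ params  # per-example logits, shape (device_batch, 3)

p_eval_step = jax.pmap(eval_step)

num_devices = jax.local_device_count()
params = jnp.stack([jnp.eye(3)] * num_devices)  # "replicated" toy parameters

# 5 examples: in general not divisible by the number of devices
batch = {"inputs": np.random.randn(5, 3).astype(np.float32)}

logits = pad_shard_unpad(p_eval_step)(params, batch, min_device_batch=2)
print(logits.shape)  # (5, 3) -- the padding rows were stripped before returning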
@@ -20,6 +20,7 @@ Fine-tuning the library models for summarization.
 import json
 import logging
+import math
 import os
 import sys
 import time
@@ -41,7 +42,7 @@ import optax
 import transformers
 from filelock import FileLock
 from flax import jax_utils, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
@@ -335,26 +336,28 @@ class TrainState(train_state.TrainState):
         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
     """
-    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-    Shuffle batches if `shuffle` is `True`.
+    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
+    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
     """
-    steps_per_epoch = len(dataset) // batch_size
-
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
+        batch_idx = np.asarray(batch_idx)
     else:
-        batch_idx = jnp.arange(len(dataset))
+        batch_idx = np.arange(len(dataset))

-    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+    if drop_last:
+        steps_per_epoch = len(dataset) // batch_size
+        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+    else:
+        steps_per_epoch = math.ceil(len(dataset) / batch_size)
+        batch_idx = np.array_split(batch_idx, steps_per_epoch)

     for idx in batch_idx:
         batch = dataset[idx]
-        batch = {k: jnp.array(v) for k, v in batch.items()}
-        batch = shard(batch)
+        batch = {k: np.array(v) for k, v in batch.items()}

         yield batch
@@ -706,7 +709,8 @@ def main():
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()
     steps_per_epoch = len(train_dataset) // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs
@@ -850,6 +854,7 @@ def main():
         # train
         for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
             batch = next(train_loader)
+            batch = shard(batch)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
@@ -867,21 +872,23 @@ def main():
         eval_preds = []
         eval_labels = []

-        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-        eval_steps = len(eval_dataset) // eval_batch_size
+        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
         for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
             # Model forward
             batch = next(eval_loader)
             labels = batch["labels"]

-            metrics = p_eval_step(state.params, batch)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                state.params, batch, min_device_batch=per_device_eval_batch_size
+            )
             eval_metrics.append(metrics)

             # generation
             if data_args.predict_with_generate:
-                generated_ids = p_generate_step(state.params, batch)
+                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                 eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+                eval_labels.extend(labels)

         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
@@ -920,21 +927,23 @@ def main():
         pred_generations = []
         pred_labels = []

-        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
-        pred_steps = len(predict_dataset) // eval_batch_size
+        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
+        pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
         for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
             # Model forward
             batch = next(pred_loader)
             labels = batch["labels"]

-            metrics = p_eval_step(state.params, batch)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                state.params, batch, min_device_batch=per_device_eval_batch_size
+            )
             pred_metrics.append(metrics)

             # generation
             if data_args.predict_with_generate:
-                generated_ids = p_generate_step(state.params, batch)
+                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                 pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+                pred_labels.extend(labels)

         # normalize prediction metrics
         pred_metrics = get_metrics(pred_metrics)
...
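One side effect visible in the summarization hunks: the loader now yields plain host NumPy batches, so the train loop adds its own batch = shard(batch) (training keeps drop_last=True, so the batch stays divisible by the device count), while evaluation and prediction rely on pad_shard_unpad instead. A quick look at what shard does to such a batch, with toy shapes:

import jax
import numpy as np
from flax.training.common_utils import shard

num_devices = jax.local_device_count()
batch = {"input_ids": np.zeros((8 * num_devices, 128), dtype=np.int32)}

sharded = shard(batch)
# shard() just reshapes the leading axis into (devices, per_device_batch, ...)
print(sharded["input_ids"].shape)  # (num_devices, 8, 128)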
@@ -16,12 +16,12 @@
 """ Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
 import json
 import logging
+import math
 import os
 import random
 import sys
 import time
 from dataclasses import dataclass, field
-from itertools import chain
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple
@@ -35,7 +35,7 @@ import jax.numpy as jnp
 import optax
 import transformers
 from flax import struct, traverse_util
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -300,11 +300,15 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
 def glue_eval_data_collator(dataset: Dataset, batch_size: int):
-    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
-    for i in range(len(dataset) // batch_size):
-        batch = dataset[i * batch_size : (i + 1) * batch_size]
+    """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
+    batch_idx = np.arange(len(dataset))
+
+    steps_per_epoch = math.ceil(len(dataset) / batch_size)
+    batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+    for idx in batch_idx:
+        batch = dataset[idx]
         batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)

         yield batch
@@ -521,8 +525,9 @@ def main():
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
-    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()

     learning_rate_fn = create_learning_rate_fn(
         len(train_dataset),
@@ -621,26 +626,15 @@ def main():
             eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
             for batch in tqdm(
                 eval_loader,
-                total=len(eval_dataset) // eval_batch_size,
+                total=math.ceil(len(eval_dataset) / eval_batch_size),
                 desc="Evaluating ...",
                 position=2,
             ):
                 labels = batch.pop("labels")
-                predictions = p_eval_step(state, batch)
-                metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
-
-            # evaluate also on leftover examples (not divisible by batch_size)
-            num_leftover_samples = len(eval_dataset) % eval_batch_size
-
-            # make sure leftover batch is evaluated on one device
-            if num_leftover_samples > 0 and jax.process_index() == 0:
-                # take leftover samples
-                batch = eval_dataset[-num_leftover_samples:]
-                batch = {k: np.array(v) for k, v in batch.items()}
-                labels = batch.pop("labels")
-                predictions = eval_step(unreplicate(state), batch)
-                metric.add_batch(predictions=predictions, references=labels)
+                predictions = pad_shard_unpad(p_eval_step)(
+                    state, batch, min_device_batch=per_device_eval_batch_size
+                )
+                metric.add_batch(predictions=np.array(predictions), references=labels)

             eval_metric = metric.compute()
...
@@ -16,6 +16,7 @@
 """ Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
 import json
 import logging
+import math
 import os
 import random
 import sys
@@ -36,7 +37,7 @@ import jax.numpy as jnp
 import optax
 import transformers
 from flax import struct, traverse_util
-from flax.jax_utils import replicate, unreplicate
+from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository
@@ -351,11 +352,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
 def eval_data_collator(dataset: Dataset, batch_size: int):
-    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
-    for i in range(len(dataset) // batch_size):
-        batch = dataset[i * batch_size : (i + 1) * batch_size]
+    """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
+    batch_idx = np.arange(len(dataset))
+
+    steps_per_epoch = math.ceil(len(dataset) / batch_size)
+    batch_idx = np.array_split(batch_idx, steps_per_epoch)
+
+    for idx in batch_idx:
+        batch = dataset[idx]
         batch = {k: np.array(v) for k, v in batch.items()}
-        batch = shard(batch)

         yield batch
@@ -600,6 +605,7 @@ def main():
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

     train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
     eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

     learning_rate_fn = create_learning_rate_fn(
@@ -728,34 +734,16 @@ def main():
             # evaluate
             for batch in tqdm(
                 eval_data_collator(eval_dataset, eval_batch_size),
-                total=len(eval_dataset) // eval_batch_size,
+                total=math.ceil(len(eval_dataset) / eval_batch_size),
                 desc="Evaluating ...",
                 position=2,
             ):
                 labels = batch.pop("labels")
-                predictions = p_eval_step(state, batch)
-                predictions = np.array([pred for pred in chain(*predictions)])
-                labels = np.array([label for label in chain(*labels)])
-                labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
-                preds, refs = get_labels(predictions, labels)
-                metric.add_batch(
-                    predictions=preds,
-                    references=refs,
-                )
-
-            # evaluate also on leftover examples (not divisible by batch_size)
-            num_leftover_samples = len(eval_dataset) % eval_batch_size
-
-            # make sure leftover batch is evaluated on one device
-            if num_leftover_samples > 0 and jax.process_index() == 0:
-                # take leftover samples
-                batch = eval_dataset[-num_leftover_samples:]
-                batch = {k: np.array(v) for k, v in batch.items()}
-
-                labels = batch.pop("labels")
-                predictions = eval_step(unreplicate(state), batch)
-                labels = np.array(labels)
-                labels[np.array(batch["attention_mask"]) == 0] = -100
+                predictions = pad_shard_unpad(p_eval_step)(
+                    state, batch, min_device_batch=per_device_eval_batch_size
+                )
+                predictions = np.array(predictions)
+                labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
                 preds, refs = get_labels(predictions, labels)
                 metric.add_batch(
                     predictions=preds,
@@ -791,28 +779,12 @@ def main():
         eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
         for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
             labels = batch.pop("labels")
-            predictions = p_eval_step(state, batch)
-            predictions = np.array([pred for pred in chain(*predictions)])
-            labels = np.array([label for label in chain(*labels)])
+            predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
+            predictions = np.array(predictions)
             labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
             preds, refs = get_labels(predictions, labels)
             metric.add_batch(predictions=preds, references=refs)

-        # evaluate also on leftover examples (not divisible by batch_size)
-        num_leftover_samples = len(eval_dataset) % eval_batch_size
-
-        # make sure leftover batch is evaluated on one device
-        if num_leftover_samples > 0 and jax.process_index() == 0:
-            # take leftover samples
-            batch = eval_dataset[-num_leftover_samples:]
-            batch = {k: np.array(v) for k, v in batch.items()}
-            labels = np.array(batch.pop("labels"))
-            predictions = eval_step(unreplicate(state), batch)
-            labels[np.array(batch["attention_mask"]) == 0] = -100
-            preds, refs = get_labels(predictions, labels)
-            metric.add_batch(predictions=preds, references=refs)
-
         eval_metrics = compute_metrics()

         if jax.process_index() == 0:
...
@@ -40,7 +40,7 @@ import jax.numpy as jnp
 import optax
 import transformers
 from flax import jax_utils
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
@@ -368,7 +368,8 @@ def main():
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()
     steps_per_epoch = len(train_dataset) // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs
@@ -398,7 +399,7 @@ def main():
         shuffle=False,
         num_workers=data_args.preprocessing_num_workers,
         persistent_workers=True,
-        drop_last=True,
+        drop_last=False,
         collate_fn=collate_fn,
     )
@@ -532,8 +533,9 @@ def main():
                 eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
                 for batch in eval_loader:
                     # Model forward
-                    batch = shard(batch)
-                    metrics = p_eval_step(state.params, batch)
+                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+                        state.params, batch, min_device_batch=per_device_eval_batch_size
+                    )
                     eval_metrics.append(metrics)

                     eval_step_progress_bar.update(1)
...
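run_image_classification.py reaches the same end point through its torch DataLoader: with drop_last=False the loader emits one final, smaller batch instead of discarding those images, and pad_shard_unpad absorbs the odd size at eval time. A standalone illustration of just the loader behaviour, using toy tensors rather than the script's own collate_fn:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(10, 3, 224, 224))
loader = DataLoader(dataset, batch_size=4, drop_last=False)

# drop_last=False keeps the tail: batch sizes 4, 4, 2 instead of 4, 4
print([images.shape[0] for (images,) in loader])  # [4, 4, 2]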