Unverified commit 4212bb0d, authored by Duong A. Nguyen and committed by GitHub

[Re-submit] Compute true loss Flax examples (#19504)



* Compute true loss

* fixup

* final

* final

* final

* Update examples/flax/language-modeling/run_bart_dlm_flax.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* jax.tree_map => jax.tree_util.tree_map

* Compute true loss

* final

* fixup

* final

* final

* Update examples/flax/language-modeling/run_bart_dlm_flax.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* jax.tree_map => jax.tree_util.tree_map
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 0903fc80
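
Note on the change: the diff below replaces the per-device mean loss (`loss.sum() / padding_mask.sum()`, later averaged across devices with `jax.lax.pmean`) with a globally normalized loss. Each device now returns its summed loss together with its count of non-padded labels, both are summed over the `batch` axis with `jax.lax.psum`, and the division happens once. The sketch below is not part of the commit; it uses plain NumPy and made-up numbers to stand in for two pmapped devices and shows why the two normalizations differ when the devices hold different numbers of non-padded tokens.

```python
# Illustrative sketch only (hypothetical values, NumPy standing in for two pmapped devices):
# mean of per-device means vs. global sum(loss) / sum(num_labels).
import numpy as np

# Per-token losses and padding masks on two hypothetical devices.
loss_d0, mask_d0 = np.array([2.0, 2.0, 0.0, 0.0]), np.array([1.0, 1.0, 0.0, 0.0])
loss_d1, mask_d1 = np.array([1.0, 1.0, 1.0, 1.0]), np.array([1.0, 1.0, 1.0, 1.0])

# Old behaviour: each device computes loss.sum() / mask.sum(), then pmean across devices.
per_device_means = [
    (loss_d0 * mask_d0).sum() / mask_d0.sum(),  # 2.0
    (loss_d1 * mask_d1).sum() / mask_d1.sum(),  # 1.0
]
old_loss = np.mean(per_device_means)  # 1.5

# New behaviour: psum the summed losses and the label counts, divide once.
total_loss = (loss_d0 * mask_d0).sum() + (loss_d1 * mask_d1).sum()  # 8.0
total_labels = mask_d0.sum() + mask_d1.sum()                        # 6.0
true_loss = total_loss / total_labels  # ~1.33, weights every non-padded token equally

print(old_loss, true_loss)
```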
@@ -335,7 +335,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         batch_idx = np.arange(len(dataset))
 
-
     for idx in range(steps):
         start_idx = batch_size * idx
         end_idx = batch_size * (idx + 1)
@@ -347,7 +346,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
 
-
 def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):
     if train_time:
         summary_writer.scalar("train_time", train_time, step)
@@ -782,11 +780,9 @@ def main():
         num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
 
-
         for idx in range(num_splits):
             if not block_size:
                 _ds = ds
             else:
                 start_idx = block_size * idx
                 end_idx = block_size * (idx + 1)
-
@@ -926,8 +922,9 @@ def main():
         # ignore padded tokens from loss
         loss = loss * padding_mask
-        loss = loss.sum() / padding_mask.sum()
-        return loss
+        loss = loss.sum()
+        num_labels = padding_mask.sum()
+        return loss, num_labels
 
     # Define gradient update step fn
     def train_step(state, batch, label_smoothing_factor=0.0):
@@ -936,29 +933,38 @@ def main():
         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
-            return loss
+            loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+            return loss, num_labels
 
-        grad_fn = jax.value_and_grad(compute_loss)
-        loss, grad = grad_fn(state.params)
-        grad = jax.lax.pmean(grad, "batch")
+        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # true loss = total loss / total samples
+        loss = jax.lax.psum(loss, "batch")
+        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
 
         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
         return new_state, metrics
 
     # Define eval fn
     def eval_step(params, batch, label_smoothing_factor=0.0):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
 
-        # summarize metrics
+        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # true loss = total loss / total samples
+        loss = jax.lax.psum(loss, "batch")
+        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+
         metrics = {"loss": loss}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
         return metrics
 
     # Define generation function
@@ -1024,7 +1030,6 @@ def main():
         ckpt_dir: str = "",
         is_prediction=False,
     ):
-
         logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
 
         metrics = []
@@ -1103,12 +1108,10 @@ def main():
         logger.info(desc)
 
         if jax.process_index() == 0:
             if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
                 os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
-
             if metrics:
-
                 # Save metrics (only for the evaluation/prediction being done along with training)
                 if has_tensorboard and training_args.do_train:
                     write_metric(
@@ -1143,7 +1146,6 @@ def main():
     input_rng = None
 
-
     if training_args.do_train:
         cur_step = 0
         train_time = 0
         epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
@@ -1166,7 +1168,6 @@ def main():
         # train
         for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
             cur_step += 1
-
             batch = next(train_batches)
             batch_start = time.time()
@@ -1177,7 +1178,6 @@ def main():
             # log and save info
             if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
-
                 _train_metric = unreplicate(train_metric)
 
                 desc = (
                     f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
@@ -1217,7 +1217,6 @@ def main():
         # log and save info
         if training_args.logging_steps <= 0:
             logger.info(desc)
-
             with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
......
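
The `train_step` hunks above hinge on `jax.value_and_grad(..., has_aux=True)`: the loss closure now returns a `(loss, num_labels)` pair, gradients are taken with respect to the scalar loss only, and the auxiliary label count is threaded out so it can be `psum`-ed. A minimal standalone sketch of that API, using a toy quadratic loss rather than the example scripts' models:

```python
import jax
import jax.numpy as jnp


def compute_loss(params, x):
    # Scalar loss plus an auxiliary value that should not be differentiated.
    loss = jnp.sum((params * x) ** 2)
    num_labels = jnp.asarray(x.size, dtype=jnp.float32)
    return loss, num_labels


# has_aux=True tells JAX the closure returns (value_to_differentiate, aux).
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

params = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# The primal output is the (loss, aux) pair; the gradient covers the loss only.
(loss, num_labels), grads = grad_fn(params, x)
print(loss, num_labels, grads)  # 73.0, 2.0, [18.0, 64.0]
```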
@@ -351,7 +351,7 @@ The example script uses the 🤗 Datasets library. You can easily customize them
 To setup all relevant files for training, let's create a directory.
 
 ```bash
-mkdir ./norwegian-roberta-base
+mkdir ./norwegian-bart-base
 ```
 
 ### Train tokenizer
......
@@ -799,19 +799,25 @@ def main():
             loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
 
             # take average
-            loss = loss.sum() / label_mask.sum()
-            return loss
+            loss = loss.sum()
+            num_labels = label_mask.sum()
+            return loss, num_labels
 
-        grad_fn = jax.value_and_grad(loss_fn)
-        loss, grad = grad_fn(state.params)
-        grad = jax.lax.pmean(grad, "batch")
-        new_state = state.apply_gradients(grads=grad)
-        metrics = jax.lax.pmean(
-            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
-        )
+        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # true loss = total loss / total samples
+        loss = jax.lax.psum(loss, "batch")
+        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
+
+        new_state = state.apply_gradients(grads=grad)
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
         return new_state, metrics, new_dropout_rng
 
     # Create parallel version of the train step
@@ -888,7 +894,7 @@ def main():
                 num_eval_samples = len(tokenized_datasets["validation"])
                 # Avoid using jax.numpy here in case of TPU training
                 eval_samples_idx = np.arange(num_eval_samples)
-                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
 
                 eval_metrics = []
                 for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -903,9 +909,9 @@ def main():
                 # normalize eval metrics
                 eval_metrics = get_metrics(eval_metrics)
-                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
                 eval_normalizer = eval_metrics.pop("normalizer")
-                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+                eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
 
                 # Update progress bar
                 epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
@@ -917,7 +923,7 @@ def main():
             if cur_step % training_args.save_steps == 0 and cur_step > 0:
                 # save checkpoint after each epoch and push checkpoint to the hub
                 if jax.process_index() == 0:
-                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
                     model.save_pretrained(training_args.output_dir, params=params)
                     tokenizer.save_pretrained(training_args.output_dir)
                     if training_args.push_to_hub:
@@ -928,7 +934,7 @@ def main():
         num_eval_samples = len(tokenized_datasets["validation"])
         # Avoid using jax.numpy here in case of TPU training
         eval_samples_idx = np.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
 
         eval_metrics = []
         for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -943,9 +949,9 @@ def main():
         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
+        eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
         eval_normalizer = eval_metrics.pop("normalizer")
-        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+        eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
 
         try:
             perplexity = math.exp(eval_metrics["loss"])
......
@@ -723,18 +723,25 @@ def main():
             loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
 
             # take average
-            loss = loss.sum() / label_mask.sum()
-            return loss
+            loss = loss.sum()
+            num_labels = label_mask.sum()
+            return loss, num_labels
 
-        grad_fn = jax.value_and_grad(loss_fn)
-        loss, grad = grad_fn(state.params)
-        grad = jax.lax.pmean(grad, "batch")
+        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # true loss = total loss / total samples
+        loss = jax.lax.psum(loss, "batch")
+        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
+
         new_state = state.apply_gradients(grads=grad)
-        metrics = jax.lax.pmean(
-            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
-        )
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
         return new_state, metrics, new_dropout_rng
......
@@ -328,7 +328,6 @@ class FlaxDataCollatorForT5MLM:
     decoder_start_token_id: int
 
-
    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
        # convert list to dict and tensorize input
        batch = BatchEncoding(
            {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
@@ -397,7 +396,6 @@ class FlaxDataCollatorForT5MLM:
        return input_ids
 
-
    def random_spans_noise_mask(self, length):
        """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Noise mask consisting of random spans of noise tokens.
......
@@ -784,8 +784,9 @@ def main():
         # ignore padded tokens from loss
         loss = loss * padding_mask
-        loss = loss.sum() / padding_mask.sum()
-        return loss
+        loss = loss.sum()
+        num_labels = padding_mask.sum()
+        return loss, num_labels
 
     # Define gradient update step fn
     def train_step(state, batch, label_smoothing_factor=0.0):
@@ -794,29 +795,38 @@ def main():
         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
-            return loss
+            loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+            return loss, num_labels
 
-        grad_fn = jax.value_and_grad(compute_loss)
-        loss, grad = grad_fn(state.params)
-        grad = jax.lax.pmean(grad, "batch")
+        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # true loss = total loss / total samples
+        loss = jax.lax.psum(loss, "batch")
+        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
 
         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
         return new_state, metrics
 
     # Define eval fn
     def eval_step(params, batch, label_smoothing_factor=0.0):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
 
-        # summarize metrics
+        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # true loss = total loss / total samples
+        loss = jax.lax.psum(loss, "batch")
+        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
+
         metrics = {"loss": loss}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
         return metrics
 
     # Define generation function
......
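
Several hunks also rename `jax.tree_map` to `jax.tree_util.tree_map`; the top-level alias is deprecated in recent JAX releases, while the `tree_util` entry point keeps the same behaviour. A small illustrative sketch of the leaf-wise mapping used for metric normalization (the dictionary values are made up, not the scripts' real metrics):

```python
import jax

# tree_map applies a function to every leaf of a pytree, here a dict of metrics.
metrics = {"loss": 4.0, "accuracy": 1.8}
normalizer = 2.0

normalized = jax.tree_util.tree_map(lambda x: x / normalizer, metrics)
print(normalized)  # {'accuracy': 0.9, 'loss': 2.0}
```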