Unverified Commit 208df208 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Adapt examples to be able to use eval_steps and save_steps (#12543)



* fix_torch_device_generate_test

* remove @

* up

* up

* correct

* upload
Co-authored-by: default avatarPatrick von Platen <patrick@huggingface.co>
parent 2870fd19
...@@ -141,6 +141,8 @@ Next we can run the example script to pretrain the model: ...@@ -141,6 +141,8 @@ Next we can run the example script to pretrain the model:
--adam_beta1="0.9" \ --adam_beta1="0.9" \
--adam_beta2="0.98" \ --adam_beta2="0.98" \
--logging_steps="500" \ --logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub --push_to_hub
``` ```
...@@ -234,6 +236,8 @@ Next we can run the example script to pretrain the model: ...@@ -234,6 +236,8 @@ Next we can run the example script to pretrain the model:
--overwrite_output_dir \ --overwrite_output_dir \
--num_train_epochs="20" \ --num_train_epochs="20" \
--logging_steps="500" \ --logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub --push_to_hub
``` ```
...@@ -370,6 +374,8 @@ Next we can run the example script to pretrain the model: ...@@ -370,6 +374,8 @@ Next we can run the example script to pretrain the model:
--overwrite_output_dir \ --overwrite_output_dir \
--num_train_epochs="10" \ --num_train_epochs="10" \
--logging_steps="500" \ --logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub --push_to_hub
``` ```
......
...@@ -587,6 +587,7 @@ def main(): ...@@ -587,6 +587,7 @@ def main():
train_metrics = [] train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ============================== # ======================== Evaluating ==============================
eval_metrics = [] eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
...@@ -599,7 +600,6 @@ def main(): ...@@ -599,7 +600,6 @@ def main():
# normalize eval metrics # normalize eval metrics
eval_metrics = get_metrics(eval_metrics) eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
try: try:
...@@ -608,7 +608,7 @@ def main(): ...@@ -608,7 +608,7 @@ def main():
eval_metrics["perplexity"] = float("inf") eval_metrics["perplexity"] = float("inf")
# Print metrics and update progress bar # Print metrics and update progress bar
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})" desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
epochs.write(desc) epochs.write(desc)
epochs.desc = desc epochs.desc = desc
...@@ -617,6 +617,7 @@ def main(): ...@@ -617,6 +617,7 @@ def main():
cur_step = epoch * (len(train_dataset) // train_batch_size) cur_step = epoch * (len(train_dataset) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step) write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub # save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0: if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params)) params = jax.device_get(unreplicate(state.params))
...@@ -624,7 +625,7 @@ def main(): ...@@ -624,7 +625,7 @@ def main():
training_args.output_dir, training_args.output_dir,
params=params, params=params,
push_to_hub=training_args.push_to_hub, push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}", commit_message=f"Saving weights and logs of step {cur_step}",
) )
......
...@@ -621,6 +621,7 @@ if __name__ == "__main__": ...@@ -621,6 +621,7 @@ if __name__ == "__main__":
train_metrics = [] train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ============================== # ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"]) num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples) eval_samples_idx = jnp.arange(num_eval_samples)
...@@ -643,15 +644,14 @@ if __name__ == "__main__": ...@@ -643,15 +644,14 @@ if __name__ == "__main__":
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar # Update progress bar
epochs.desc = ( epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
)
# Save metrics # Save metrics
if has_tensorboard and jax.process_index() == 0: if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step) write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub # save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0: if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
...@@ -659,5 +659,5 @@ if __name__ == "__main__": ...@@ -659,5 +659,5 @@ if __name__ == "__main__":
training_args.output_dir, training_args.output_dir,
params=params, params=params,
push_to_hub=training_args.push_to_hub, push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}", commit_message=f"Saving weights and logs of step {cur_step}",
) )
...@@ -737,6 +737,7 @@ if __name__ == "__main__": ...@@ -737,6 +737,7 @@ if __name__ == "__main__":
train_metrics = [] train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ============================== # ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"]) num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples) eval_samples_idx = jnp.arange(num_eval_samples)
...@@ -757,16 +758,20 @@ if __name__ == "__main__": ...@@ -757,16 +758,20 @@ if __name__ == "__main__":
eval_metrics = jax.tree_map(jnp.mean, eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
# Update progress bar # Update progress bar
epochs.write( epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
)
# Save metrics # Save metrics
if has_tensorboard and jax.process_index() == 0: if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step) write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub # save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0: if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub) model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment