Unverified Commit d0f7508a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Correct logging steps flax (#12515)

* fix_torch_device_generate_test

* remove @

* push
parent bb4ac2b5
...@@ -574,7 +574,7 @@ def main(): ...@@ -574,7 +574,7 @@ def main():
cur_step = epoch * (len(train_dataset) // train_batch_size) + step cur_step = epoch * (len(train_dataset) // train_batch_size) + step
if cur_step % training_args.logging_steps and cur_step > 0: if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics # Save metrics
train_metric = unreplicate(train_metric) train_metric = unreplicate(train_metric)
train_time += time.time() - train_start train_time += time.time() - train_start
......
...@@ -608,7 +608,7 @@ if __name__ == "__main__": ...@@ -608,7 +608,7 @@ if __name__ == "__main__":
cur_step = epoch * num_train_samples + step cur_step = epoch * num_train_samples + step
if cur_step % training_args.logging_steps and cur_step > 0: if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics # Save metrics
train_metric = jax_utils.unreplicate(train_metric) train_metric = jax_utils.unreplicate(train_metric)
train_time += time.time() - train_start train_time += time.time() - train_start
......
...@@ -724,7 +724,7 @@ if __name__ == "__main__": ...@@ -724,7 +724,7 @@ if __name__ == "__main__":
cur_step = epoch * num_train_samples + step cur_step = epoch * num_train_samples + step
if cur_step % training_args.logging_steps and cur_step > 0: if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics # Save metrics
train_metric = jax_utils.unreplicate(train_metric) train_metric = jax_utils.unreplicate(train_metric)
train_time += time.time() - train_start train_time += time.time() - train_start
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment