Unverified Commit 2bd950ca authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

[Flax] token-classification model steps enumerate start from 1 (#14547)

* step start from 1

* Updated cur_step calcualtion
parent cea17acd
...@@ -598,7 +598,7 @@ def main(): ...@@ -598,7 +598,7 @@ def main():
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
train_metrics.append(train_metric) train_metrics.append(train_metric)
cur_step = epoch * step_per_epoch + step cur_step = (epoch * step_per_epoch) + (step + 1)
if cur_step % training_args.logging_steps == 0 and cur_step > 0: if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics # Save metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment