Unverified commit 113eaa75, authored by Patrick von Platen, committed by GitHub
Browse files

correct example script (#11726)

parent bd3b599c
...@@ -119,12 +119,6 @@ def parse_args(): ...@@ -119,12 +119,6 @@ def parse_args():
default=None, default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.", help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
) )
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument( parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
) )
...@@ -457,13 +451,13 @@ def main(): ...@@ -457,13 +451,13 @@ def main():
logger.info(f"===== Starting training ({num_epochs} epochs) =====") logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0 train_time = 0
# make sure weights are replicated on each device
state = replicate(state)
for epoch in range(1, num_epochs + 1): for epoch in range(1, num_epochs + 1):
logger.info(f"Epoch {epoch}") logger.info(f"Epoch {epoch}")
logger.info(" Training...") logger.info(" Training...")
# make sure weights are replicated on each device
state = replicate(state)
train_start = time.time() train_start = time.time()
train_metrics = [] train_metrics = []
rng, input_rng, dropout_rng = jax.random.split(rng, 3) rng, input_rng, dropout_rng = jax.random.split(rng, 3)
...@@ -501,6 +495,9 @@ def main(): ...@@ -501,6 +495,9 @@ def main():
predictions = eval_step(state, batch) predictions = eval_step(state, batch)
metric.add_batch(predictions=predictions, references=labels) metric.add_batch(predictions=predictions, references=labels)
# make sure weights are replicated on each device
state = replicate(state)
eval_metric = metric.compute() eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}") logger.info(f" Done! Eval metrics: {eval_metric}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment