"tests/t5/test_modeling_tf_t5.py" did not exist on "e983da0e7d91c100e6e35efcb8a69c8cd41d6e09"
Unverified Commit 113eaa75 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

correct example script (#11726)

parent bd3b599c
......@@ -119,12 +119,6 @@ def parse_args():
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
......@@ -457,13 +451,13 @@ def main():
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0
# make sure weights are replicated on each device
state = replicate(state)
for epoch in range(1, num_epochs + 1):
logger.info(f"Epoch {epoch}")
logger.info(" Training...")
# make sure weights are replicated on each device
state = replicate(state)
train_start = time.time()
train_metrics = []
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
......@@ -501,6 +495,9 @@ def main():
predictions = eval_step(state, batch)
metric.add_batch(predictions=predictions, references=labels)
# make sure weights are replicated on each device
state = replicate(state)
eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment