Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
113eaa75
"tests/t5/test_modeling_tf_t5.py" did not exist on "e983da0e7d91c100e6e35efcb8a69c8cd41d6e09"
Unverified
Commit
113eaa75
authored
May 14, 2021
by
Patrick von Platen
Committed by
GitHub
May 14, 2021
Browse files
correct example script (#11726)
parent
bd3b599c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
9 deletions
+6
-9
examples/flax/text-classification/run_flax_glue.py
examples/flax/text-classification/run_flax_glue.py
+6
-9
No files found.
examples/flax/text-classification/run_flax_glue.py
View file @
113eaa75
...
...
@@ -119,12 +119,6 @@ def parse_args():
default
=
None
,
help
=
"Total number of training steps to perform. If provided, overrides num_train_epochs."
,
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
parser
.
add_argument
(
"--num_warmup_steps"
,
type
=
int
,
default
=
0
,
help
=
"Number of steps for the warmup in the lr scheduler."
)
...
...
@@ -457,13 +451,13 @@ def main():
logger
.
info
(
f
"===== Starting training (
{
num_epochs
}
epochs) ====="
)
train_time
=
0
# make sure weights are replicated on each device
state
=
replicate
(
state
)
for
epoch
in
range
(
1
,
num_epochs
+
1
):
logger
.
info
(
f
"Epoch
{
epoch
}
"
)
logger
.
info
(
" Training..."
)
# make sure weights are replicated on each device
state
=
replicate
(
state
)
train_start
=
time
.
time
()
train_metrics
=
[]
rng
,
input_rng
,
dropout_rng
=
jax
.
random
.
split
(
rng
,
3
)
...
...
@@ -501,6 +495,9 @@ def main():
predictions
=
eval_step
(
state
,
batch
)
metric
.
add_batch
(
predictions
=
predictions
,
references
=
labels
)
# make sure weights are replicated on each device
state
=
replicate
(
state
)
eval_metric
=
metric
.
compute
()
logger
.
info
(
f
" Done! Eval metrics:
{
eval_metric
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment