Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
113eaa75
Unverified
Commit
113eaa75
authored
May 14, 2021
by
Patrick von Platen
Committed by
GitHub
May 14, 2021
Browse files
correct example script (#11726)
parent
bd3b599c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
9 deletions
+6
-9
examples/flax/text-classification/run_flax_glue.py
examples/flax/text-classification/run_flax_glue.py
+6
-9
No files found.
examples/flax/text-classification/run_flax_glue.py
View file @
113eaa75
...
@@ -119,12 +119,6 @@ def parse_args():
...
@@ -119,12 +119,6 @@ def parse_args():
default
=
None
,
default
=
None
,
help
=
"Total number of training steps to perform. If provided, overrides num_train_epochs."
,
help
=
"Total number of training steps to perform. If provided, overrides num_train_epochs."
,
)
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_warmup_steps"
,
type
=
int
,
default
=
0
,
help
=
"Number of steps for the warmup in the lr scheduler."
"--num_warmup_steps"
,
type
=
int
,
default
=
0
,
help
=
"Number of steps for the warmup in the lr scheduler."
)
)
...
@@ -457,13 +451,13 @@ def main():
...
@@ -457,13 +451,13 @@ def main():
logger
.
info
(
f
"===== Starting training (
{
num_epochs
}
epochs) ====="
)
logger
.
info
(
f
"===== Starting training (
{
num_epochs
}
epochs) ====="
)
train_time
=
0
train_time
=
0
# make sure weights are replicated on each device
state
=
replicate
(
state
)
for
epoch
in
range
(
1
,
num_epochs
+
1
):
for
epoch
in
range
(
1
,
num_epochs
+
1
):
logger
.
info
(
f
"Epoch
{
epoch
}
"
)
logger
.
info
(
f
"Epoch
{
epoch
}
"
)
logger
.
info
(
" Training..."
)
logger
.
info
(
" Training..."
)
# make sure weights are replicated on each device
state
=
replicate
(
state
)
train_start
=
time
.
time
()
train_start
=
time
.
time
()
train_metrics
=
[]
train_metrics
=
[]
rng
,
input_rng
,
dropout_rng
=
jax
.
random
.
split
(
rng
,
3
)
rng
,
input_rng
,
dropout_rng
=
jax
.
random
.
split
(
rng
,
3
)
...
@@ -501,6 +495,9 @@ def main():
...
@@ -501,6 +495,9 @@ def main():
predictions
=
eval_step
(
state
,
batch
)
predictions
=
eval_step
(
state
,
batch
)
metric
.
add_batch
(
predictions
=
predictions
,
references
=
labels
)
metric
.
add_batch
(
predictions
=
predictions
,
references
=
labels
)
# make sure weights are replicated on each device
state
=
replicate
(
state
)
eval_metric
=
metric
.
compute
()
eval_metric
=
metric
.
compute
()
logger
.
info
(
f
" Done! Eval metrics:
{
eval_metric
}
"
)
logger
.
info
(
f
" Done! Eval metrics:
{
eval_metric
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment