Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d0f7508a
Unverified
Commit
d0f7508a
authored
Jul 05, 2021
by
Patrick von Platen
Committed by
GitHub
Jul 05, 2021
Browse files
[Flax] Correct logging steps flax (#12515)
* fix_torch_device_generate_test * remove @ * push
parent
bb4ac2b5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
3 deletions
+3
-3
examples/flax/language-modeling/run_clm_flax.py
examples/flax/language-modeling/run_clm_flax.py
+1
-1
examples/flax/language-modeling/run_mlm_flax.py
examples/flax/language-modeling/run_mlm_flax.py
+1
-1
examples/flax/language-modeling/run_t5_mlm_flax.py
examples/flax/language-modeling/run_t5_mlm_flax.py
+1
-1
No files found.
examples/flax/language-modeling/run_clm_flax.py
View file @
d0f7508a
...
...
@@ -574,7 +574,7 @@ def main():
cur_step
=
epoch
*
(
len
(
train_dataset
)
//
train_batch_size
)
+
step
if
cur_step
%
training_args
.
logging_steps
and
cur_step
>
0
:
if
cur_step
%
training_args
.
logging_steps
==
0
and
cur_step
>
0
:
# Save metrics
train_metric
=
unreplicate
(
train_metric
)
train_time
+=
time
.
time
()
-
train_start
...
...
examples/flax/language-modeling/run_mlm_flax.py
View file @
d0f7508a
...
...
@@ -608,7 +608,7 @@ if __name__ == "__main__":
cur_step
=
epoch
*
num_train_samples
+
step
if
cur_step
%
training_args
.
logging_steps
and
cur_step
>
0
:
if
cur_step
%
training_args
.
logging_steps
==
0
and
cur_step
>
0
:
# Save metrics
train_metric
=
jax_utils
.
unreplicate
(
train_metric
)
train_time
+=
time
.
time
()
-
train_start
...
...
examples/flax/language-modeling/run_t5_mlm_flax.py
View file @
d0f7508a
...
...
@@ -724,7 +724,7 @@ if __name__ == "__main__":
cur_step
=
epoch
*
num_train_samples
+
step
if
cur_step
%
training_args
.
logging_steps
and
cur_step
>
0
:
if
cur_step
%
training_args
.
logging_steps
==
0
and
cur_step
>
0
:
# Save metrics
train_metric
=
jax_utils
.
unreplicate
(
train_metric
)
train_time
+=
time
.
time
()
-
train_start
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment