chenpangpang / transformers, commit 7d321b76 (unverified)
Authored Jul 07, 2021 by Patrick von Platen; committed by GitHub on Jul 07, 2021
[Flax] Allow retraining from save checkpoint (#12559)
* fix_torch_device_generate_test
* remove @
* finish
parent 1d6623c6
Showing 3 changed files with 22 additions and 3 deletions
examples/flax/language-modeling/run_mlm_flax.py  (+8, −1)
examples/flax/language-modeling/run_t5_mlm_flax.py  (+6, −1)
examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py  (+8, −1)
examples/flax/language-modeling/run_mlm_flax.py  (view file @ 7d321b76)
@@ -478,7 +478,14 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
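The pattern added above is small enough to try in isolation. Below is a minimal sketch of that branch, assuming transformers with Flax support is installed; the roberta-base config and the checkpoint path are illustrative assumptions, not part of the commit.

import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

config = AutoConfig.from_pretrained("roberta-base")  # any config works; roberta-base is only an example
model_name_or_path = None  # e.g. "./mlm-checkpoint" (hypothetical) to resume from saved weights

if model_name_or_path:
    # A checkpoint path was given: restore the Flax weights saved by a previous run.
    model = FlaxAutoModelForMaskedLM.from_pretrained(
        model_name_or_path, config=config, seed=0, dtype=jnp.float32
    )
else:
    # No checkpoint given: initialize the weights randomly from the config, as before.
    model = FlaxAutoModelForMaskedLM.from_config(config, seed=0, dtype=jnp.float32)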
examples/flax/language-modeling/run_t5_mlm_flax.py  (view file @ 7d321b76)
@@ -588,6 +588,11 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxT5ForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

     # Data collator
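Note that run_t5_mlm_flax.py differs slightly from the two masked-LM scripts: it builds a concrete model class rather than going through an Auto class, so its random-initialization branch calls the FlaxT5ForConditionalGeneration constructor on the config directly, while FlaxAutoModelForMaskedLM uses from_config for the same purpose. The resume branch is identical in all three files: from_pretrained restores the saved weights whenever model_args.model_name_or_path is set.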
examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py  (view file @ 7d321b76)
@@ -427,7 +427,14 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
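Taken together, the three diffs give each Flax example the same resume behaviour: launching the script with --model_name_or_path (the argument that populates model_args.model_name_or_path) pointing at a directory of previously saved Flax weights continues training from that checkpoint, while omitting it keeps the old behaviour of initializing from the config. The round trip below is a sketch of that workflow, assuming transformers with Flax support is installed; the checkpoint directory name and the roberta-base config are hypothetical choices for illustration.

import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

config = AutoConfig.from_pretrained("roberta-base")  # example config only

# First run: start from a random initialization and save a checkpoint.
model = FlaxAutoModelForMaskedLM.from_config(config, seed=0, dtype=jnp.float32)
model.save_pretrained("./mlm-checkpoint")  # hypothetical output directory

# Later run: restore the saved weights instead of re-initializing,
# which is what the new model_name_or_path branch does.
resumed = FlaxAutoModelForMaskedLM.from_pretrained(
    "./mlm-checkpoint", config=config, seed=0, dtype=jnp.float32
)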