chenpangpang / transformers · Commits

Commit d1500d91 (unverified)
pass decay_mask fn to optimizer (#12087)
Authored Jun 09, 2021 by Suraj Patil; committed by GitHub on Jun 09, 2021
Parent: d472bd7b
Showing 1 changed file with 11 additions and 1 deletion.

examples/flax/language-modeling/run_mlm_flax.py (+11, -1)
@@ -38,7 +38,7 @@ import flax
 import jax
 import jax.numpy as jnp
 import optax
-from flax import jax_utils
+from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from transformers import (
@@ -504,6 +504,15 @@ if __name__ == "__main__":
         schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
     )
 
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        return traverse_util.unflatten_dict(flat_mask)
+
     # create adam optimizer
     adamw = optax.adamw(
         learning_rate=linear_decay_lr_schedule_fn,
@@ -511,6 +520,7 @@ if __name__ == "__main__":
         b2=training_args.adam_beta2,
         eps=1e-8,
         weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
     )
 
     # Setup train state
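For context, here is a minimal, self-contained sketch (not part of the commit) of how the decay_mask_fn added above interacts with the mask argument of optax.adamw. The toy parameter tree and the hyperparameter values are illustrative assumptions, not taken from run_mlm_flax.py; only the masking logic mirrors the diff.

# Sketch: applying weight-decay masking with Optax, assuming the same
# decay_mask_fn logic as in the commit. Parameter names below are made up.
import jax
import jax.numpy as jnp
import optax
from flax import traverse_util


def decay_mask_fn(params):
    # Decay everything except biases and LayerNorm scales (same rule as the diff).
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)


# Hypothetical Flax-style parameter tree mimicking a transformer block.
params = {
    "dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
    "LayerNorm": {"scale": jnp.ones((4,)), "bias": jnp.zeros((4,))},
}

print(decay_mask_fn(params))
# -> {'dense': {'kernel': True, 'bias': False},
#     'LayerNorm': {'scale': False, 'bias': False}}

# Weight decay is only applied where the mask is True. Note that AdamW needs
# the current params passed to update() so it can apply the decay term.
adamw = optax.adamw(learning_rate=1e-3, weight_decay=0.01, mask=decay_mask_fn)
opt_state = adamw.init(params)
grads = jax.tree_util.tree_map(jnp.zeros_like, params)  # dummy gradients
updates, opt_state = adamw.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Excluding bias and LayerNorm scale parameters from weight decay is a common convention for transformer training; passing the mask function (rather than a precomputed mask tree) lets Optax rebuild the mask to match whatever parameter structure the model provides.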