Unverified commit d1500d91 authored by Suraj Patil, committed by GitHub

pass decay_mask fn to optimizer (#12087)

parent d472bd7b
@@ -38,7 +38,7 @@
 import flax
 import jax
 import jax.numpy as jnp
 import optax
-from flax import jax_utils
+from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from transformers import (
@@ -504,6 +504,15 @@ if __name__ == "__main__":
         schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
     )
 
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        return traverse_util.unflatten_dict(flat_mask)
+
     # create adam optimizer
     adamw = optax.adamw(
         learning_rate=linear_decay_lr_schedule_fn,
@@ -511,6 +520,7 @@ if __name__ == "__main__":
         b2=training_args.adam_beta2,
         eps=1e-8,
         weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
     )
 
     # Setup train state
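
For readers unfamiliar with Optax's masking, below is a minimal, self-contained sketch (not part of this commit) of how decay_mask_fn behaves; the toy parameter names, shapes, and hyperparameter values are made up for illustration.

import jax.numpy as jnp
import optax
from flax import traverse_util


def decay_mask_fn(params):
    # Flatten the nested parameter dict into {path_tuple: leaf} form.
    flat_params = traverse_util.flatten_dict(params)
    # Decay everything except biases and LayerNorm scale parameters.
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)


# Toy parameter tree with made-up names/shapes, mimicking a Flax model.
params = {
    "dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
    "LayerNorm": {"scale": jnp.ones((4,)), "bias": jnp.zeros((4,))},
}

print(decay_mask_fn(params))
# {'dense': {'kernel': True, 'bias': False},
#  'LayerNorm': {'scale': False, 'bias': False}}

# The mask is passed to optax.adamw so the weight-decay term is applied
# only where the mask is True (here: the dense kernel).
adamw = optax.adamw(learning_rate=1e-3, weight_decay=0.01, mask=decay_mask_fn)
opt_state = adamw.init(params)

Passing the mask to optax.adamw restricts weight decay to the parameters marked True, which follows the common convention of not decaying biases and normalization scales.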