"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c206fc8779f9e4579b9fe155b396fe3d5479dde9"
Unverified commit 2d42915a, authored by Suraj Patil, committed by GitHub

[examples/flax] add adafactor optimizer (#12544)



* add adafactor

* Update examples/flax/language-modeling/run_mlm_flax.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 208df208
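The change is the same in each of the three modified example scripts: instead of always building optax.adamw, the optimizer is chosen based on training_args.adafactor. Below is a minimal, self-contained sketch of that pattern; the step count, learning rate, and the use_adafactor flag are illustrative stand-ins for the values the scripts read from TrainingArguments.

import optax

num_train_steps = 10_000  # stand-in for the value derived from dataset size and epochs
learning_rate = 5e-3      # stand-in for training_args.learning_rate
use_adafactor = True      # stand-in for training_args.adafactor

# Linear decay schedule, analogous to the scripts' linear_decay_lr_schedule_fn.
linear_decay_lr_schedule_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps
)

if use_adafactor:
    # Adafactor is created with its default parameters; only the schedule is passed in.
    optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn)
else:
    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=0.9,             # stand-in for training_args.adam_beta1
        b2=0.999,           # stand-in for training_args.adam_beta2
        eps=1e-8,           # stand-in for training_args.adam_epsilon
        weight_decay=0.01,  # stand-in for training_args.weight_decay
    )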
@@ -2,4 +2,4 @@ datasets >= 1.1.3
 jax>=0.2.8
 jaxlib>=0.1.59
 flax>=0.3.4
-optax>=0.0.8
+optax>=0.0.9
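The requirements pin is bumped to optax>=0.0.9 alongside the new code path. If you are unsure whether your installed optax is recent enough, a quick illustrative check (the assumption here is that the version bump is what guarantees optax.adafactor is present):

import optax

if not hasattr(optax, "adafactor"):
    raise ImportError("optax is too old; install optax>=0.0.9 to use the adafactor option")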
@@ -489,17 +489,24 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )

     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

     def loss_fn(logits, labels):
         shift_logits = logits[..., :-1, :]
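Only the adamw branch takes a mask; optax.adafactor has no mask argument, so decay_mask_fn is simply unused when adafactor is selected. For context, here is a rough sketch of what a decay_mask_fn of this shape looks like, reconstructed from the flatten/unflatten calls visible in the hunk above; the exact set of excluded parameter names is an assumption.

from flax import traverse_util

def decay_mask_fn(params):
    # Flatten the nested params dict into {("block", ..., "kernel"): array, ...}.
    flat_params = traverse_util.flatten_dict(params)
    # Apply weight decay to everything except biases and LayerNorm scales
    # (assumed exclusion list, following common practice).
    flat_mask = {
        path: path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)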
@@ -513,17 +513,24 @@ if __name__ == "__main__":
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=1e-8,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )

     # Setup train state
-    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
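Both branches receive the same linear_decay_lr_schedule_fn. In these examples that schedule is assembled from optax schedule primitives, roughly as sketched below; the warmup and total step counts are illustrative, and the exact construction in each script may differ slightly.

import optax

num_train_steps = 10_000   # illustrative
num_warmup_steps = 500     # illustrative, e.g. training_args.warmup_steps
learning_rate = 5e-3       # illustrative

# Linear warmup from 0 to the peak learning rate, then linear decay back to 0.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
)
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.0,
    transition_steps=num_train_steps - num_warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
)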
@@ -635,16 +635,23 @@ if __name__ == "__main__":
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )

     # Setup train state
-    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
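Because both optimizers are plain optax gradient transformations, nothing downstream of TrainState.create changes: the train step still just calls state.apply_gradients. The following self-contained toy example shows the adafactor path end to end; the model, data, and learning rate are all made up and stand in for the pretrained Flax model and real training inputs used by the scripts.

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

# Toy linear model standing in for the pretrained Flax model.
model = nn.Dense(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))["params"]

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adafactor(learning_rate=1e-2),  # the new branch; adamw would plug in identically
)

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # apply_gradients delegates to whichever optax transformation was passed as tx.
    return state.apply_gradients(grads=grads), loss

batch = {"x": jnp.ones((8, 3)), "y": jnp.zeros((8, 1))}
state, loss = train_step(state, batch)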