"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "78a2654049624322817fbb592fd61dce78b58822"
Unverified Commit 4674061b authored by Nicholas Vadivelu, committed by GitHub

Fix weight decay masking in `run_flax_glue.py` (#11964)



* Fix weight decay masking in `run_flax_glue.py`

Issues with the previous implementation (each demonstrated in the sketch below):
- The `dict` from `traverse_util.flatten_dict` has keys which are tuples of strings, not one long string with the path separated by periods.
- `optax.masked` applies the transformation wherever the mask is True, so the previous masks were inverted.
- Flax's LayerNorm calls the scale parameter `scale`, not `weight`.
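
To make these concrete, here is a minimal, self-contained sketch. It is not part of the commit: the `Toy` module and the `a`/`b` mask names are made up purely for illustration.

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax import traverse_util
from flax.core import unfreeze


class Toy(nn.Module):
    # Hypothetical two-layer module: Dense followed by LayerNorm.
    @nn.compact
    def __call__(self, x):
        return nn.LayerNorm()(nn.Dense(4)(x))


params = Toy().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))["params"]

# (1) and (3): the flattened keys are tuples of path components, not
# period-separated strings, and LayerNorm's parameter is named "scale",
# so a substring check against "LayerNorm.weight" can never match.
print(sorted(traverse_util.flatten_dict(unfreeze(params))))
# [('Dense_0', 'bias'), ('Dense_0', 'kernel'),
#  ('LayerNorm_0', 'bias'), ('LayerNorm_0', 'scale')]

# (2): optax.masked applies the inner transformation wherever the mask is
# True and passes the rest through unchanged, so True means "transform this".
grads = {"a": jnp.ones(()), "b": jnp.ones(())}
tx = optax.masked(optax.sgd(learning_rate=1.0), {"a": True, "b": False})
updates, _ = tx.update(grads, tx.init(grads))
print(updates)  # "a" -> -1.0 (SGD applied), "b" -> 1.0 (passed through)
```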

* Fix formatting with black

* adapt results
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 61c50634
@@ -63,15 +63,15 @@ In the Tensorboard results linked below, the random seed of each model is equal
 | Task  | Metric                       | Acc (best run) | Acc (avg/5runs) | Stdev     | Metrics                                                                  |
 |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
-| CoLA  | Matthew's corr               | 60.82          | 59.04           | 1.17      | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/)  |
-| SST-2 | Accuracy                     | 92.43          | 92.13           | 0.38      | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/)  |
-| MRPC  | F1/Accuracy                  | 89.90/88.98    | 88.98/85.30     | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/)  |
-| STS-B | Pearson/Spearman corr.       | 89.04/88.70    | 88.94/88.63     | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/)  |
-| QQP   | Accuracy/F1                  | 90.82/87.54    | 90.75/87.53     | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/)  |
-| MNLI  | Matched acc.                 | 84.10          | 83.84           | 0.16      | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/)  |
-| QNLI  | Accuracy                     | 91.07          | 90.83           | 0.19      | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/)  |
-| RTE   | Accuracy                     | 66.06          | 64.76           | 1.04      | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/)  |
-| WNLI  | Accuracy                     | 46.48          | 37.01           | 6.83      | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/)  |
+| CoLA  | Matthew's corr               | 60.57          | 59.04           | 1.06      | [tfhub.dev](https://tensorboard.dev/experiment/lfr2adVpRtmLDALKrElkzg/)  |
+| SST-2 | Accuracy                     | 92.66          | 92.23           | 0.57      | [tfhub.dev](https://tensorboard.dev/experiment/jYvfv2trRHKMjoWnXVwrZA/)  |
+| MRPC  | F1/Accuracy                  | 89.90/85.78    | 88.97/84.36     | 0.72/1.09 | [tfhub.dev](https://tensorboard.dev/experiment/bo3W3DEoRw2Q7YXjWrJkfg/)  |
+| STS-B | Pearson/Spearman corr.       | 89.04/88.70    | 88.94/88.63     | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/fxVwbLD7QpKhbot0r9rn2w/)  |
+| QQP   | Accuracy/F1                  | 90.81/87.58    | 90.76/87.51     | 0.05/0.06 | [tfhub.dev](https://tensorboard.dev/experiment/di089Rc9TZmsnKRMrYNLsA/)  |
+| MNLI  | Matched acc.                 | 84.10          | 83.80           | 0.16      | [tfhub.dev](https://tensorboard.dev/experiment/JgNCGHDJSRaW6HBx6YQFYQ/)  |
+| QNLI  | Accuracy                     | 91.01          | 90.82           | 0.17      | [tfhub.dev](https://tensorboard.dev/experiment/Bq7cMGJnQMSggYgL8qNGeQ/)  |
+| RTE   | Accuracy                     | 66.06          | 64.76           | 1.04      | [tfhub.dev](https://tensorboard.dev/experiment/66Eq24bhRjqN6CEhgDSGqQ/)  |
+| WNLI  | Accuracy                     | 46.48          | 37.01           | 6.83      | [tfhub.dev](https://tensorboard.dev/experiment/TAqcnddqTkWvVEeGaWwIdQ/)  |
 
 Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
 website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
@@ -165,25 +165,17 @@ def create_train_state(
         logits_fn: Callable = struct.field(pytree_node=False)
         loss_fn: Callable = struct.field(pytree_node=False)
 
-    # Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
-    def adamw(decay):
-        return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay)
-
-    def traverse(fn):
-        def mask(data):
-            flat = traverse_util.flatten_dict(data)
-            return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
-
-        return mask
-
-    # We use Optax's "masking" functionality to create a multi-optimizer, one
-    # with weight decay and the other without. Note masking means the optimizer
-    # will ignore these paths.
-    decay_path = lambda p: not any(x in p for x in ["bias", "LayerNorm.weight"])  # noqa: E731
-
-    tx = optax.chain(
-        optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
-        optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))),
-    )
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        return traverse_util.unflatten_dict(flat_mask)
+
+    tx = optax.adamw(
+        learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn
+    )
 
     if is_regression:
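
For reference, here is a quick sanity check of what the new `decay_mask_fn` returns. This sketch is not from the commit; the parameter names are hypothetical and only mimic the structure of a Flax BERT layer.

```python
import jax.numpy as jnp
from flax import traverse_util


def decay_mask_fn(params):
    # Same function as in the diff above: True means "apply weight decay".
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


# Hypothetical parameter tree: one dense block plus a LayerNorm.
params = {
    "attention": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))},
    "LayerNorm": {"scale": jnp.ones((2,)), "bias": jnp.ones((2,))},
}
print(decay_mask_fn(params))
# {'attention': {'kernel': True, 'bias': False},
#  'LayerNorm': {'scale': False, 'bias': False}}
```

Only the dense kernel is decayed; every bias and the LayerNorm scale are excluded, which is exactly the convention used when fine-tuning BERT-style models.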