Unverified Commit 81332868 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Example scripts - correct weight decay (#12409)

* fix_torch_device_generate_test

* remove @

* finish

* finish

* correct style
parent aecae533
......@@ -477,9 +477,15 @@ def main():
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
# Note that this mask is specifically adapted for FlaxGPT2.
# For other models, one should correct the layer norm parameter naming
# accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
flat_mask = {
path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
for path in flat_params
}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
......
......@@ -508,6 +508,9 @@ if __name__ == "__main__":
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
# Note that this mask is specifically adapted for FlaxBERT-like models.
# For other models, one should correct the layer norm parameter naming
# accordingly.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
......
......@@ -626,7 +626,10 @@ if __name__ == "__main__":
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
flat_mask = {
path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
for path in flat_params
}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
......
......@@ -578,9 +578,15 @@ def main():
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
# Note that this mask is specifically adapted for FlaxBart.
# For FlaxT5, one should correct the layer norm parameter naming
# accordingly - see `run_t5_mlm_flax.py` e.g.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
layer_norm_params = [
(name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
]
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment