    Fix weight decay masking in `run_flax_glue.py` (#11964) · 4674061b
    Nicholas Vadivelu authored
    
    
    * Fix weight decay masking in `run_flax_glue.py`
    
    Issues with the previous implementation:
- The `dict` returned by `traverse_util.flatten_dict` has keys that are tuples of strings (one entry per path segment), not a single string with the path separated by periods.
- `optax.masked` applies the transformation wherever the mask is True, so the previous masks were inverted.
- Flax's LayerNorm names its scale parameter `scale`, not `weight`.
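The three points above can be illustrated with a minimal sketch of a decay mask function. It is not necessarily the exact code in this commit. To stay runnable without Flax installed, `flatten_dict` and `unflatten_dict` below are small stand-ins that mimic `flax.traverse_util` (tuple keys). The toy `params` tree is hypothetical, and the code assumes LayerNorm submodules are literally named `LayerNorm`, as in Flax BERT-style models.

```python
from typing import Any, Dict, Tuple

Params = Dict[str, Any]


def flatten_dict(tree: Params, prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Stand-in for flax.traverse_util.flatten_dict: keys become tuples of path segments."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_dict(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_dict(flat: Dict[Tuple[str, ...], Any]) -> Params:
    """Stand-in for flax.traverse_util.unflatten_dict: rebuild the nested tree."""
    tree: Params = {}
    for path, value in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return tree


def decay_mask_fn(params: Params) -> Params:
    """Return True where weight decay SHOULD be applied (optax masks act where True)."""
    flat_params = flatten_dict(params)
    flat_mask = {
        # Keys are tuples such as ("LayerNorm", "scale"), not "LayerNorm.scale".
        # Exclude every bias and every LayerNorm scale (Flax's name, not "weight").
        path: path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")
        for path in flat_params
    }
    return unflatten_dict(flat_mask)


# Hypothetical toy parameter tree.
params = {
    "dense": {"kernel": 1.0, "bias": 0.0},
    "LayerNorm": {"scale": 1.0, "bias": 0.0},
}
print(decay_mask_fn(params))
# {'dense': {'kernel': True, 'bias': False}, 'LayerNorm': {'scale': False, 'bias': False}}
```

With Flax and optax installed, the stand-ins can be replaced by `flax.traverse_util.flatten_dict`/`unflatten_dict`, and the function passed as `mask=decay_mask_fn` to `optax.adamw`, so only the kernels receive weight decay.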
    
    * Fix formatting with black
    
    * adapt results
    Co-authored-by: Patrick von Platen <patrick@huggingface.co>