add set_weight_decay to support custom weight decay setting (#5671)
* add set_weight_decay
* Update _utils.py
* refactor code
* fix import
* add set_weight_decay
* fix lint
* fix lint
* replace split_normalization_params with set_weight_decay
* simplfy the code
* refactor code
* refactor code
* fix lint
* remove unused
* Update test_ops.py
* Update train.py
* Update _utils.py
* Update train.py
* add set_weight_decay
* add set_weight_decay
* Update _utils.py
* Update test_ops.py
* Change `--transformer-weight-decay` to `--transformer-embedding-decay`
Co-authored-by:
Vasilis Vryniotis <datumbox@users.noreply.github.com>
Showing
Please register or sign in to comment