Enable mixed precision training for hubert_pretrain_model (#2854)
Summary: Addresses https://github.com/pytorch/audio/issues/2847. In mixed precision training, the dtype of `mask_embedding` is **not** automatically converted to fp16. This PR fixes the issue by casting `mask_embedding` to the dtype of `x`, enabling mixed precision training.

Pull Request resolved: https://github.com/pytorch/audio/pull/2854

Reviewed By: carolineechen

Differential Revision: D41343486

Pulled By: nateanl

fbshipit-source-id: 4a5cbb429ff8ba5d3c439a3d5acb5094f66bf705
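A minimal sketch of the fix pattern described above, assuming a simplified masking module (the class name `MaskedEncoder` and the surrounding structure are illustrative assumptions, not the exact torchaudio source; only the `mask_embedding` cast reflects the change):

```python
import torch
from torch import nn


class MaskedEncoder(nn.Module):
    # Hypothetical module illustrating the fix; in torchaudio the equivalent
    # logic lives in the HuBERT pretraining model's masking code path.
    def __init__(self, embed_dim: int):
        super().__init__()
        # Learnable embedding that replaces features at masked positions.
        # As an nn.Parameter it stays fp32 even under torch.autocast.
        self.mask_embedding = nn.Parameter(torch.empty(embed_dim).uniform_())

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Under autocast, ``x`` may be fp16 while ``mask_embedding`` remains
        # fp32; index assignment requires matching dtypes, so cast the
        # embedding to ``x.dtype`` before writing it into the masked slots.
        x[mask] = self.mask_embedding.to(x.dtype)
        return x


# Minimal check, simulating a half-precision activation:
enc = MaskedEncoder(embed_dim=8)
x = torch.randn(2, 10, 8).half()
mask = torch.zeros(2, 10, dtype=torch.bool)
mask[:, :3] = True
out = enc(x, mask)  # succeeds; without the cast this raises a dtype-mismatch error
```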