Manually launch wgrad accumulation and reduce in backward_dw() instead of backward() (#1976)
* disable wgrad accumulation and reduce in backward() And manually launch it in backward_dw() Signed-off-by:Hongbin Liu <hongbinl@nvidia.com> * format Signed-off-by:
Hongbin Liu <hongbinl@nvidia.com> * refactor Signed-off-by:
Hongbin Liu <hongbinl@nvidia.com> * refactor Signed-off-by:
Hongbin Liu <hongbinl@nvidia.com> * set skip_backward_post_hook to True only if delay_wgrad_compute is True Signed-off-by:
Hongbin Liu <hongbinl@nvidia.com> * format Signed-off-by:
Hongbin Liu <hongbinl@nvidia.com> --------- Signed-off-by:
Hongbin Liu <hongbinl@nvidia.com> Co-authored-by:
Hongbin Liu <hongbinl@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing
Please register or sign in to comment