[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear (#1488)
* fix fuse_wgrad_accumulation for GroupedLinear Signed-off-by:Xin Yao <xiny@nvidia.com> * fix fuse_wgrad_accumulation for GroupedLinear Signed-off-by:
Xin Yao <xiny@nvidia.com> * update tests Signed-off-by:
Xin Yao <xiny@nvidia.com> --------- Signed-off-by:
Xin Yao <xiny@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com>
Showing
Please register or sign in to comment