Unverified Commit d2945c6a authored by Tong Liu's avatar Tong Liu Committed by GitHub
Browse files

[PyTorch] Use dummy wgrad in GroupedLinear (#2305)



dummy wgrad
Signed-off-by: default avatartongliu <tongliu@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 87cb26c6
...@@ -13,6 +13,7 @@ import transformer_engine_torch as tex ...@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from .base import ( from .base import (
get_dummy_wgrad,
get_multi_stream_cublas_workspace, get_multi_stream_cublas_workspace,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -447,18 +448,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -447,18 +448,15 @@ class _GroupedLinear(torch.autograd.Function):
): ):
weight.grad_added_to_main_grad = True weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False): if getattr(weight, "zero_out_wgrad", False):
wgrad = torch.zeros( wgrad = get_dummy_wgrad(
weight.main_grad.shape, list(weight.main_grad.shape),
dtype=weight.dtype, weight.dtype,
device=torch.cuda.current_device(), zero=True,
requires_grad=False,
) )
else: else:
wgrad = torch.empty( wgrad = get_dummy_wgrad(
weight.main_grad.shape, list(weight.main_grad.shape),
dtype=weight.dtype, weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
) )
elif ctx.fuse_wgrad_accumulation: elif ctx.fuse_wgrad_accumulation:
wgrad = None wgrad = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment