Unverified commit d2945c6a authored by Tong Liu, committed by GitHub

[PyTorch] Use dummy wgrad in GroupedLinear (#2305)

Replace the per-call torch.zeros/torch.empty wgrad allocations in
GroupedLinear's backward with the cached get_dummy_wgrad helper, used when
the real gradient has already been accumulated into weight.main_grad.

Signed-off-by: tongliu <tongliu@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 87cb26c6
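
The get_dummy_wgrad helper imported in the diff below is defined in .base and
is not part of this change. As a rough sketch of the assumed idea (the cache
layout and names here are illustrative, not the real implementation): keep one
reusable buffer per (shape, dtype) key and hand it back on every call, so the
backward pass stops allocating a throwaway wgrad tensor each time the real
gradient has already been fused into weight.main_grad.

    import torch

    # Illustrative sketch only: the actual helper lives in .base and may differ.
    _dummy_wgrad_cache = {}

    def get_dummy_wgrad(shape, dtype, zero=False):
        """Return a reusable placeholder tensor for weight.grad.

        When wgrad accumulation is fused into weight.main_grad, autograd
        still expects a gradient tensor to be returned, but its contents
        are never read. Caching one buffer per (shape, dtype) avoids a
        fresh full-size allocation on every backward call.
        """
        key = (tuple(shape), dtype)
        if key not in _dummy_wgrad_cache:
            _dummy_wgrad_cache[key] = torch.empty(
                shape,
                dtype=dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        if zero:
            # Callers that set zero_out_wgrad expect zeros rather than
            # uninitialized memory.
            _dummy_wgrad_cache[key].zero_()
        return _dummy_wgrad_cache[key]
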
@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from .base import (
+    get_dummy_wgrad,
     get_multi_stream_cublas_workspace,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
@@ -447,18 +448,15 @@ class _GroupedLinear(torch.autograd.Function):
                 ):
                     weight.grad_added_to_main_grad = True
                     if getattr(weight, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
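
A note on the design, under the same assumptions as the sketch above: with
fuse_wgrad_accumulation enabled, the tensor returned as weight.grad is a
placeholder whose values are never read, yet GroupedLinear previously paid one
full-size allocation per weight per backward pass just to satisfy autograd.
Returning a shared cached buffer reduces that to one allocation per
(shape, dtype) for the whole run. A minimal check of the assumed caching
behavior:

    # Hypothetical check: repeated calls with the same shape/dtype are
    # assumed to return the same underlying storage.
    a = get_dummy_wgrad([4096, 1024], torch.bfloat16)
    b = get_dummy_wgrad([4096, 1024], torch.bfloat16)
    assert a.data_ptr() == b.data_ptr()
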