Unverified commit 2d7020e2, authored by Xin Yao, committed by GitHub
Browse files

[PyTorch] Fix wgrads for GroupedLinear when weights don't require grad (#1258)



Fix wgrad for GroupedLinear when weights don't require grad
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9001081d
......@@ -443,36 +443,38 @@ class _GroupedLinear(torch.autograd.Function):
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
wgrad = None
return wgrad
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
else:
wgrad = None
return wgrad
wgrad_list = [None] * ctx.num_gemms
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment