Unverified Commit 2d7020e2 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[PyTorch] Fix wgrads for GroupedLinear when weights don't require grad (#1258)



Fix wgrad for GroupedLinear when weights don't require grad
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9001081d
@@ -443,36 +443,38 @@ class _GroupedLinear(torch.autograd.Function):

New version of the changed region (the `handle_custom_ddp_from_mcore` helper is now
applied only when the weights require grad; otherwise `wgrad_list` is filled with
`None`, and the `grad_biases` initialization is moved after the wgrad handling):

```python
            clear_tensor_data(*inputmats)
            clear_tensor_data(*inputmats_t)

            def handle_custom_ddp_from_mcore(w, wgrad):
                if w.requires_grad:
                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                        w.grad_added_to_main_grad = True
                        if getattr(w, "zero_out_wgrad", False):
                            wgrad = torch.zeros(
                                w.main_grad.shape,
                                dtype=w.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False,
                            )
                        else:
                            wgrad = torch.empty(
                                w.main_grad.shape,
                                dtype=w.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False,
                            )
                    elif ctx.fuse_wgrad_accumulation:
                        wgrad = None
                else:
                    wgrad = None
                return wgrad

            wgrad_list = [
                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
            ]
        else:
            wgrad_list = [None] * ctx.num_gemms

        if not ctx.use_bias:
            grad_biases = [None] * ctx.num_gemms

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
```

(NOTE: reconstructed from a fused two-column diff; exact indentation relative to the
enclosing `backward` method should be confirmed against the repository.)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment