Unverified commit 2d7020e2, authored by Xin Yao, committed by GitHub

[PyTorch] Fix wgrads for GroupedLinear when weights don't require grad (#1258)



Fix wgrad for GroupedLinear when weights don't require grad
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9001081d
@@ -443,9 +443,6 @@ class _GroupedLinear(torch.autograd.Function):
             clear_tensor_data(*inputmats)
             clear_tensor_data(*inputmats_t)
 
-            if not ctx.use_bias:
-                grad_biases = [None] * ctx.num_gemms
-
             def handle_custom_ddp_from_mcore(w, wgrad):
                 if w.requires_grad:
                     if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
@@ -473,6 +470,11 @@ class _GroupedLinear(torch.autograd.Function):
             wgrad_list = [
                 handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
             ]
+        else:
+            wgrad_list = [None] * ctx.num_gemms
+
+        if not ctx.use_bias:
+            grad_biases = [None] * ctx.num_gemms
 
         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
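The added `else` branch fills `wgrad_list` with `None` placeholders when the weights are frozen, which matches the `torch.autograd.Function` contract of returning `None` gradients for inputs that do not require grad. A minimal standalone sketch of the same pattern (illustrative only, not the TransformerEngine implementation; `_ToyGroupedLinear` and its shapes are made up for this example):

```python
import torch

class _ToyGroupedLinear(torch.autograd.Function):
    """Toy stand-in for a grouped GEMM: two independent weights, one per input chunk."""

    @staticmethod
    def forward(ctx, inp, w0, w1):
        ctx.save_for_backward(inp, w0, w1)
        x0, x1 = inp.chunk(2, dim=0)
        return torch.cat([x0 @ w0.t(), x1 @ w1.t()], dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        inp, w0, w1 = ctx.saved_tensors
        x0, x1 = inp.chunk(2, dim=0)
        g0, g1 = grad_out.chunk(2, dim=0)
        # dgrad is always produced here; wgrads only when the weights require grad.
        dgrad = torch.cat([g0 @ w0, g1 @ w1], dim=0)
        if w0.requires_grad:
            wgrad_list = [g0.t() @ x0, g1.t() @ x1]
        else:
            # Mirrors the fix: emit None placeholders instead of leaving the list undefined.
            wgrad_list = [None, None]
        return (dgrad, *wgrad_list)

# Frozen weights: backward still runs cleanly and dgrad reaches the input.
inp = torch.randn(8, 16, requires_grad=True)
w0 = torch.randn(32, 16, requires_grad=False)
w1 = torch.randn(32, 16, requires_grad=False)
_ToyGroupedLinear.apply(inp, w0, w1).sum().backward()
```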