Unverified Commit b59d1d8b authored by Li Tao's avatar Li Tao Committed by GitHub
Browse files

[PyTorch] Fix issues for MCore DDP in grouped GEMM. (#1609)



fix mcore DDP error
Signed-off-by: default avatarlit <lit@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 945a559b
......@@ -78,8 +78,8 @@ class _GroupedLinear(torch.autograd.Function):
skip_fp8_weight_update,
*weights_and_biases,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
......@@ -180,7 +180,12 @@ class _GroupedLinear(torch.autograd.Function):
ctx.weights_shape_1 = weights[0].shape[1]
tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
*weights,
*biases,
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
......@@ -220,7 +225,8 @@ class _GroupedLinear(torch.autograd.Function):
N = ctx.num_gemms
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
biases = saved_tensors[2 * N : 3 * N]
origin_weights = saved_tensors[2 * N : 3 * N]
biases = saved_tensors[3 * N : 4 * N]
main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO
......@@ -311,21 +317,24 @@ class _GroupedLinear(torch.autograd.Function):
# Deallocate input tensor
clear_tensor_data(*inputmats)
def handle_custom_ddp_from_mcore(w, wgrad):
def handle_custom_ddp_from_mcore(weight, wgrad):
if ctx.weights_requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(
weight, "grad_added_to_main_grad"
):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
......@@ -336,7 +345,8 @@ class _GroupedLinear(torch.autograd.Function):
return wgrad
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
handle_custom_ddp_from_mcore(weight, wgrad)
for weight, wgrad in zip(origin_weights, wgrad_list)
]
else:
wgrad_list = [None] * ctx.num_gemms
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment