"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "dc903e70acf9dba74d6afaa50e7b5650d6b9338a"
Unverified Commit b59d1d8b authored by Li Tao, committed by GitHub

[PyTorch] Fix issues for MCore DDP in grouped GEMM. (#1609)



Fix MCore DDP error.

Signed-off-by: lit <lit@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 945a559b
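Why the extra tensors are saved: Megatron-Core's custom DDP keeps its bookkeeping (main_grad, grad_added_to_main_grad, zero_out_wgrad) directly on the original torch.nn.Parameter objects. The FP8-cast copies that were previously saved for backward do not carry those Python attributes, so the backward pass has to save the original weights as well to drive the DDP path. A minimal plain-PyTorch sketch of this point follows; it is an illustration only, not code from the commit:

import torch

weight = torch.nn.Parameter(torch.empty(16, 16))
# MCore-DDP-style bookkeeping lives on the original parameter object.
weight.main_grad = torch.zeros_like(weight)

fp8_copy = weight.detach().clone()  # stand-in for the FP8-cast weight tensor

assert hasattr(weight, "main_grad")        # the original parameter carries the buffer
assert not hasattr(fp8_copy, "main_grad")  # the cast copy does not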
@@ -78,8 +78,8 @@ class _GroupedLinear(torch.autograd.Function):
         skip_fp8_weight_update,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
@@ -180,7 +180,12 @@ class _GroupedLinear(torch.autograd.Function):
             ctx.weights_shape_1 = weights[0].shape[1]
 
-            tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
+            tensors_to_save, tensor_objects = prepare_for_saving(
+                *inputmats,
+                *weights_fp8,
+                *weights,
+                *biases,
+            )
             ctx.save_for_backward(*tensors_to_save)
             ctx.tensor_objects = tensor_objects
@@ -220,7 +225,8 @@ class _GroupedLinear(torch.autograd.Function):
             N = ctx.num_gemms
             inputmats = saved_tensors[:N]
             weights = saved_tensors[N : 2 * N]
-            biases = saved_tensors[2 * N : 3 * N]
+            origin_weights = saved_tensors[2 * N : 3 * N]
+            biases = saved_tensors[3 * N : 4 * N]
             main_grads = ctx.main_grads
 
             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
@@ -311,21 +317,24 @@ class _GroupedLinear(torch.autograd.Function):
                 # Deallocate input tensor
                 clear_tensor_data(*inputmats)
 
-            def handle_custom_ddp_from_mcore(w, wgrad):
+            def handle_custom_ddp_from_mcore(weight, wgrad):
                 if ctx.weights_requires_grad:
-                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
-                        w.grad_added_to_main_grad = True
-                        if getattr(w, "zero_out_wgrad", False):
+                    # Handle custom DDP from mcore.
+                    if ctx.fuse_wgrad_accumulation and hasattr(
+                        weight, "grad_added_to_main_grad"
+                    ):
+                        weight.grad_added_to_main_grad = True
+                        if getattr(weight, "zero_out_wgrad", False):
                             wgrad = torch.zeros(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
+                                weight.main_grad.shape,
+                                dtype=weight.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
                         else:
                             wgrad = torch.empty(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
+                                weight.main_grad.shape,
+                                dtype=weight.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
@@ -336,7 +345,8 @@ class _GroupedLinear(torch.autograd.Function):
                     return wgrad
 
                 wgrad_list = [
-                    handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
+                    handle_custom_ddp_from_mcore(weight, wgrad)
+                    for weight, wgrad in zip(origin_weights, wgrad_list)
                 ]
             else:
                 wgrad_list = [None] * ctx.num_gemms
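For readers who want to see the helper's contract in isolation, here is a standalone restatement using plain PyTorch on CPU. It mirrors the logic in the hunk above but is not the TE implementation (no ctx flags, no CUDA device placement); the parameter setup is a hypothetical stand-in for what MCore DDP prepares before backward:

import torch

def handle_custom_ddp_from_mcore(weight, wgrad):
    # If MCore DDP owns gradient accumulation, flag the parameter as handled and
    # return a dummy buffer shaped like main_grad instead of the real weight grad.
    if hasattr(weight, "grad_added_to_main_grad"):
        weight.grad_added_to_main_grad = True
        factory = torch.zeros if getattr(weight, "zero_out_wgrad", False) else torch.empty
        wgrad = factory(weight.main_grad.shape, dtype=weight.dtype, requires_grad=False)
    return wgrad

w = torch.nn.Parameter(torch.randn(4, 4))
w.main_grad = torch.zeros_like(w)     # accumulation buffer owned by MCore DDP
w.grad_added_to_main_grad = False
w.zero_out_wgrad = True               # ask for a zero-filled dummy wgrad

dummy = handle_custom_ddp_from_mcore(w, torch.randn(4, 4))
assert w.grad_added_to_main_grad
assert dummy.shape == w.main_grad.shape and not dummy.any()  # zero-filled dummy returned

This is why the list comprehension now iterates over origin_weights: only the original parameters carry main_grad and the related flags, while the FP8-cast copies do not.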