Unverified Commit 978f1d72 authored by Zhenhuan Liu's avatar Zhenhuan Liu Committed by GitHub
Browse files

Fix issues for MCore DDP. (#1474)



* Fix issues for MCore DDP.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

* Remove force data release for CPU offloading.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

* Add preserved attributes.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add main_grad to preserved attributes.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

* Change prepare_for_saving to original tensor and add .data to CPU hook.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

* Fix for LayernormLinear in FP8.
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>

---------
Signed-off-by: default avatarDennis Liu <denliu@nvidia.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6673f165
......@@ -137,7 +137,9 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
retrieve_identifier = self.offload_handler.tensor_push(
tensor.data, **self.handler_extra_kwargs
)
return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
......
......@@ -441,7 +441,7 @@ class _LayerNormLinear(torch.autograd.Function):
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
weight,
_,
origin_weight,
bias,
ln_weight,
ln_out,
......@@ -722,17 +722,22 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.requires_wgrad:
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
origin_weight.grad_added_to_main_grad = True
if getattr(origin_weight, "zero_out_wgrad", False):
wgrad = torch.zeros(
weight.main_grad.shape,
dtype=weight.dtype,
origin_weight.main_grad.shape,
dtype=origin_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = None
wgrad = torch.empty(
origin_weight.main_grad.shape,
dtype=origin_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
......
......@@ -606,7 +606,12 @@ class _Linear(torch.autograd.Function):
requires_grad=False,
)
else:
wgrad = None
wgrad = torch.empty(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
......
......@@ -28,7 +28,7 @@ def prepare_for_saving(
tensor_list.append(None)
tensor_objects_list.append(None)
elif type(tensor) in (torch.Tensor, torch.nn.Parameter):
tensor_list.append(tensor.data)
tensor_list.append(tensor)
tensor_objects_list.append(None)
else:
t, t_obj = tensor.prepare_for_saving()
......@@ -116,10 +116,7 @@ class Quantizer(abc.ABC):
"""Quantize tensor in-place"""
def quantize(
self,
tensor: torch.Tensor,
*,
out: Optional[QuantizedTensor] = None,
self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None
) -> QuantizedTensor:
"""Quantize tensor"""
if out is not None:
......@@ -159,10 +156,7 @@ class Quantizer(abc.ABC):
"""
def set_usage(
self,
*,
rowwise: Optional[bool] = None,
columnwise: Optional[bool] = None,
self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
) -> None:
"""Set how the quantized tensor is expected to be used
......@@ -194,8 +188,7 @@ class _QuantizeFunc(torch.autograd.Function):
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
_ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
......@@ -212,9 +205,7 @@ class _IdentityFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
tensor: QuantizedTensor,
init_kwargs: Optional[Dict[str, Any]] = None,
ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
......@@ -408,8 +399,7 @@ class QuantizedTensor(torch.Tensor):
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
def contiguous(
self,
memory_format: torch.memory_format = torch.contiguous_format,
self, memory_format: torch.memory_format = torch.contiguous_format
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
raise NotImplementedError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment