Unverified commit 978f1d72, authored by Zhenhuan Liu, committed by GitHub

Fix issues for MCore DDP. (#1474)



* Fix issues for MCore DDP.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Remove forced data release for CPU offloading.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Add preserved attributes.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add main_grad to preserved attributes.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Change prepare_for_saving to save the original tensor and add .data to the CPU offload hook.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

* Fix for LayerNormLinear in FP8.
Signed-off-by: Dennis Liu <denliu@nvidia.com>

---------
Signed-off-by: Dennis Liu <denliu@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6673f165
@@ -137,7 +137,9 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
         super().__init__()
 
     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
-        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
+        retrieve_identifier = self.offload_handler.tensor_push(
+            tensor.data, **self.handler_extra_kwargs
+        )
         return retrieve_identifier
 
     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
...
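The hook above now pushes `tensor.data` rather than the tensor object itself, so the offload handler keeps only the raw storage while the attribute-carrying `nn.Parameter` (now saved whole by `prepare_for_saving`, see below) stays out of the offload bookkeeping. A minimal sketch of the same idea using PyTorch's public saved-tensors hooks; the function names are hypothetical and this is not the Transformer Engine implementation:

    import torch

    def pack_to_cpu(tensor: torch.Tensor):
        # Offload only the raw storage; .data strips the Parameter wrapper.
        return tensor.device, tensor.data.to("cpu", non_blocking=True)

    def unpack_from_cpu(packed):
        device, cpu_tensor = packed
        # Bring the saved activation back to its original device for backward.
        return cpu_tensor.to(device, non_blocking=True)

    x = torch.randn(4, 4, requires_grad=True)
    w = torch.nn.Parameter(torch.randn(4, 4))
    with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
        loss = (x @ w).sum()
    loss.backward()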
@@ -441,7 +441,7 @@ class _LayerNormLinear(torch.autograd.Function):
         (  # pylint: disable=unbalanced-tuple-unpacking
             inputmat,
             weight,
-            _,
+            origin_weight,
             bias,
             ln_weight,
             ln_out,
@@ -722,17 +722,22 @@ class _LayerNormLinear(torch.autograd.Function):
             if ctx.requires_wgrad:
                 # Handle custom DDP from mcore.
-                if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
-                    weight.grad_added_to_main_grad = True
-                    if getattr(weight, "zero_out_wgrad", False):
+                if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
+                    origin_weight.grad_added_to_main_grad = True
+                    if getattr(origin_weight, "zero_out_wgrad", False):
                         wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
+                            origin_weight.main_grad.shape,
+                            dtype=origin_weight.dtype,
                             device=torch.cuda.current_device(),
                             requires_grad=False,
                         )
                     else:
-                        wgrad = None
+                        wgrad = torch.empty(
+                            origin_weight.main_grad.shape,
+                            dtype=origin_weight.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False,
+                        )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
                 else:
...
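With `fuse_wgrad_accumulation`, the real weight gradient is accumulated directly into `origin_weight.main_grad`, and `grad_added_to_main_grad` tells MCore's custom DDP that `.grad` should not be trusted. The hunk returns an allocated-but-never-read `torch.empty` buffer instead of `None` (zeros only when `zero_out_wgrad` is set), presumably so that the parameter still receives a gradient object and DDP's per-parameter hooks keep firing. A toy autograd function illustrating that contract; this is a simplified sketch, not the Transformer Engine kernel:

    import torch

    class ToyLinearFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, weight):
            ctx.save_for_backward(inp)
            ctx.weight = weight  # keep the Parameter so its attributes survive
            return inp @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            (inp,) = ctx.saved_tensors
            weight = ctx.weight
            dgrad = grad_out @ weight
            if hasattr(weight, "main_grad"):
                weight.main_grad += grad_out.t() @ inp      # fused accumulation
                weight.grad_added_to_main_grad = True       # tell DDP .grad is stale
                wgrad = torch.empty_like(weight.main_grad)  # dummy, never read
            else:
                wgrad = grad_out.t() @ inp                  # ordinary path
            return dgrad, wgrad

    w = torch.nn.Parameter(torch.randn(3, 2))
    w.main_grad = torch.zeros_like(w)  # buffer an MCore-style DDP wrapper would provide
    x = torch.randn(4, 2, requires_grad=True)
    ToyLinearFn.apply(x, w).sum().backward()
    # The real weight gradient now lives in w.main_grad; w.grad holds the dummy buffer.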
@@ -606,7 +606,12 @@ class _Linear(torch.autograd.Function):
                             requires_grad=False,
                         )
                     else:
-                        wgrad = None
+                        wgrad = torch.empty(
+                            weight.main_grad.shape,
+                            dtype=weight.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False,
+                        )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
                 else:
...
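`_Linear` gets the same treatment as `_LayerNormLinear`: once the gradient has been folded into `main_grad`, only a placeholder is handed back to autograd. For context, a toy, hypothetical view of the DDP side of this contract (this is not Megatron-Core code): the wrapper attaches a persistent `main_grad` buffer to every parameter and, after backward, trusts `main_grad` whenever a layer has set `grad_added_to_main_grad`, ignoring whatever landed in `.grad`:

    import torch

    def attach_main_grads(module: torch.nn.Module) -> None:
        # Persistent FP32 accumulation buffers, as an MCore-style DDP wrapper would create.
        for p in module.parameters():
            p.main_grad = torch.zeros_like(p, dtype=torch.float32)
            p.grad_added_to_main_grad = False

    def collect_wgrads(module: torch.nn.Module) -> dict:
        grads = {}
        for name, p in module.named_parameters():
            if getattr(p, "grad_added_to_main_grad", False):
                grads[name] = p.main_grad   # fused path: gradient already lives here
            elif p.grad is not None:
                p.main_grad += p.grad       # fallback: fold .grad into main_grad
                grads[name] = p.main_grad
        return grads

    layer = torch.nn.Linear(2, 3)
    attach_main_grads(layer)
    layer(torch.randn(4, 2)).sum().backward()
    print(sorted(collect_wgrads(layer)))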
@@ -28,7 +28,7 @@ def prepare_for_saving(
             tensor_list.append(None)
             tensor_objects_list.append(None)
         elif type(tensor) in (torch.Tensor, torch.nn.Parameter):
-            tensor_list.append(tensor.data)
+            tensor_list.append(tensor)
             tensor_objects_list.append(None)
         else:
            t, t_obj = tensor.prepare_for_saving()
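`prepare_for_saving` now stores the original tensor object instead of `tensor.data`: a `.data` view is a fresh `torch.Tensor` that does not carry the Python-level attributes MCore DDP attaches to the `nn.Parameter`, so saving it would leave `main_grad` and `grad_added_to_main_grad` unreachable from the backward pass (the `origin_weight` unpacked in the `_LayerNormLinear` hunk above). A minimal demonstration, not Transformer Engine code:

    import torch

    w = torch.nn.Parameter(torch.randn(3, 2))
    w.main_grad = torch.zeros_like(w)

    saved_data = w.data  # new view object: Python-level attributes are gone
    saved_obj = w        # original Parameter: attributes preserved

    print(hasattr(saved_data, "main_grad"))  # False
    print(hasattr(saved_obj, "main_grad"))   # True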
@@ -116,10 +116,7 @@ class Quantizer(abc.ABC):
         """Quantize tensor in-place"""
 
     def quantize(
-        self,
-        tensor: torch.Tensor,
-        *,
-        out: Optional[QuantizedTensor] = None,
+        self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None
     ) -> QuantizedTensor:
         """Quantize tensor"""
         if out is not None:
@@ -159,10 +156,7 @@ class Quantizer(abc.ABC):
         """
 
     def set_usage(
-        self,
-        *,
-        rowwise: Optional[bool] = None,
-        columnwise: Optional[bool] = None,
+        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
     ) -> None:
         """Set how the quantized tensor is expected to be used
@@ -194,8 +188,7 @@ class _QuantizeFunc(torch.autograd.Function):
     @staticmethod
     def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
+        _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor  # unused
     ) -> Tuple[Optional[torch.Tensor], ...]:
         # pylint: disable=missing-function-docstring
         # Assume that we want gradients in full precision
@@ -212,9 +205,7 @@ class _IdentityFunc(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        tensor: QuantizedTensor,
-        init_kwargs: Optional[Dict[str, Any]] = None,
+        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring
@@ -408,8 +399,7 @@ class QuantizedTensor(torch.Tensor):
         return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
 
     def contiguous(
-        self,
-        memory_format: torch.memory_format = torch.contiguous_format,
+        self, memory_format: torch.memory_format = torch.contiguous_format
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring
         raise NotImplementedError(
...