Unverified Commit d88137c4 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Debug Mcore wgrad fusion with te.ops (#2097)



* Return dummy wgrad tensors when requested by Mcore
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarJan Bielak <janekb04@icloud.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarJan Bielak <janekb04@icloud.com>
parent 40dde4dd
...@@ -12,7 +12,6 @@ from typing import Any, Optional ...@@ -12,7 +12,6 @@ from typing import Any, Optional
import torch import torch
from transformer_engine.pytorch.module.base import get_workspace
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...distributed import ( from ...distributed import (
CudaRNGStatesTracker, CudaRNGStatesTracker,
...@@ -20,18 +19,24 @@ from ...distributed import ( ...@@ -20,18 +19,24 @@ from ...distributed import (
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
) )
from ...fp8 import FP8GlobalStateManager, Recipe from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
get_workspace,
)
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
clear_tensor_data, clear_tensor_data,
devices_match, devices_match,
) )
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
def _wait_async(handle: Optional[Any]) -> None: def _wait_async(handle: Optional[Any]) -> None:
...@@ -73,7 +78,8 @@ class BasicLinear(BasicOperation): ...@@ -73,7 +78,8 @@ class BasicLinear(BasicOperation):
weight's `main_grad` attribute instead of relying on PyTorch weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be there is no guarantee that `grad` will be set or be
meaningful. meaningful. This is primarily intended to integrate with
Megatron-LM.
userbuffers_options, dict, optional userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly compute using Userbuffers. This feature is highly
...@@ -979,20 +985,22 @@ class BasicLinear(BasicOperation): ...@@ -979,20 +985,22 @@ class BasicLinear(BasicOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = ctx.saved_tensors (x_local, w) = ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = self._accumulate_into_main_grad accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weight = None grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad: if ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(self.weight, "__fsdp_param__"): weight_param = self.weight
self.weight.main_grad = self.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(self.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = self.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -1019,6 +1027,17 @@ class BasicLinear(BasicOperation): ...@@ -1019,6 +1027,17 @@ class BasicLinear(BasicOperation):
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad: if accumulate_into_main_grad:
grad_weight = None grad_weight = None
weight_param = self.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [grad_weight] return grad_input, [grad_weight]
...@@ -9,13 +9,10 @@ from typing import Optional ...@@ -9,13 +9,10 @@ from typing import Optional
import torch import torch
from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput from ...module.base import get_dummy_wgrad
from transformer_engine.pytorch.ops.op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..basic import BasicLinear, MakeExtraOutput
from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearAdd(FusedOperation): class BackwardLinearAdd(FusedOperation):
...@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation): ...@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation): ...@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
) )
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(grad_weight,), ()], [(), ()] return grad_input, [(grad_weight,), ()], [(), ()]
......
...@@ -9,13 +9,10 @@ from typing import Optional ...@@ -9,13 +9,10 @@ from typing import Optional
import torch import torch
from ..basic import BasicLinear, ConstantScale from ...module.base import get_dummy_wgrad
from ..op import (
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..basic import BasicLinear, ConstantScale
from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearScale(FusedOperation): class BackwardLinearScale(FusedOperation):
...@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation): ...@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation): ...@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
) )
if accumulate_into_main_grad:
grad_weight = None
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(), (grad_weight,)], [(), ()] return grad_input, [(), (grad_weight,)], [(), ()]
......
...@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter ...@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size from ...distributed import get_distributed_world_size
from ...module.base import ( from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub, get_ub,
get_workspace, get_workspace,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
) )
from ...tensor.quantized_tensor import Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
...@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Return gradients # Megatron-LM wgrad fusion
grad_params = [() for _ in range(len(self.basic_ops))] # Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad: if accumulate_into_main_grad:
grad_weight = None grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
# Return gradients
grad_params = [() for _ in range(len(self.basic_ops))]
grad_params[self._op_idxs["linear"]] = (grad_weight,) grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None: if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,) grad_params[self._op_idxs["bias"]] = (grad_bias,)
......
...@@ -54,7 +54,8 @@ class Linear(FusedOperation): ...@@ -54,7 +54,8 @@ class Linear(FusedOperation):
weight's `main_grad` attribute instead of relying on PyTorch weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be there is no guarantee that `grad` will be set or be
meaningful. meaningful. This is primarily intended to integrate with
Megatron-LM.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment