Commit 2389ed3f authored by yuguo's avatar yuguo
Browse files

Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7

parents 87e3e56e 58c3ac80
...@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter ...@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size from ...distributed import get_distributed_world_size
from ...module.base import ( from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub, get_ub,
get_workspace, get_workspace,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
) )
from ...tensor.quantized_tensor import Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
...@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
if hasattr(linear_op.weight, "__fsdp_param__"): weight_param = linear_op.weight
linear_op.weight.main_grad = linear_op.weight.get_main_grad() if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(linear_op.weight, "main_grad"): if not hasattr(weight_param, "main_grad"):
raise RuntimeError( raise RuntimeError(
"BasicLinear op is configured with " "BasicLinear op is configured with "
"accumulate_into_main_grad=True, " "accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute" "but weight parameter does not have main_grad attribute"
) )
grad_weight = linear_op.weight.main_grad.detach() grad_weight = weight_param.main_grad.detach()
else: else:
accumulate_into_main_grad = False accumulate_into_main_grad = False
...@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
# Clear input tensor if possible # Clear input tensor if possible
clear_tensor_data(x_local) clear_tensor_data(x_local)
# Return gradients # Megatron-LM wgrad fusion
grad_params = [() for _ in range(len(self.basic_ops))] # Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad: if accumulate_into_main_grad:
grad_weight = None grad_weight = None
weight_param = linear_op.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
# Return gradients
grad_params = [() for _ in range(len(self.basic_ops))]
grad_params[self._op_idxs["linear"]] = (grad_weight,) grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None: if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,) grad_params[self._op_idxs["bias"]] = (grad_bias,)
......
...@@ -54,7 +54,8 @@ class Linear(FusedOperation): ...@@ -54,7 +54,8 @@ class Linear(FusedOperation):
weight's `main_grad` attribute instead of relying on PyTorch weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be there is no guarantee that `grad` will be set or be
meaningful. meaningful. This is primarily intented to integrate with
Megatron-LM.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment