Commit 2389ed3f authored by yuguo

Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7

parents 87e3e56e 58c3ac80
@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
 from ...cpp_extensions import general_gemm
 from ...distributed import get_distributed_world_size
 from ...module.base import (
-    _2X_ACC_DGRAD,
-    _2X_ACC_WGRAD,
     fill_userbuffers_buffer_for_all_gather,
+    get_dummy_wgrad,
     get_ub,
     get_workspace,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
 )
 from ...tensor.quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
         # Note: Get grad tensor from param so we can accumulate
         # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
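For readers unfamiliar with the Megatron-LM wgrad fusion referenced in this hunk: the expectation is that the framework attaches a preallocated fp32 `main_grad` buffer to the weight parameter, and the backward pass accumulates the weight gradient into that buffer directly instead of populating `param.grad`. The sketch below is a minimal, standalone illustration of that contract, not TE's actual GEMM path; `accumulate_wgrad` is a hypothetical helper written only for this example.

```python
import torch

# Minimal sketch of the Megatron-LM-style main_grad contract assumed above.
# The framework (not the op) creates the fp32 accumulation buffer once.
weight = torch.nn.Parameter(torch.randn(16, 8, dtype=torch.bfloat16))
weight.main_grad = torch.zeros(16, 8, dtype=torch.float32)

def accumulate_wgrad(weight_param, grad_output, layer_input):
    """Hypothetical helper: accumulate dL/dW directly into main_grad."""
    if not hasattr(weight_param, "main_grad"):
        raise RuntimeError("accumulate_into_main_grad=True requires a main_grad buffer")
    # Detach so no autograd history is recorded on the accumulation buffer.
    grad_weight = weight_param.main_grad.detach()
    grad_weight += grad_output.float().T @ layer_input.float()
    return grad_weight

x = torch.randn(4, 8, dtype=torch.bfloat16)    # layer input
dy = torch.randn(4, 16, dtype=torch.bfloat16)  # incoming gradient
accumulate_wgrad(weight, dy, x)  # weight.main_grad now holds the accumulated wgrad
```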
@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)

-        # Return gradients
-        grad_params = [() for _ in range(len(self.basic_ops))]
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
+        # Return gradients
+        grad_params = [() for _ in range(len(self.basic_ops))]
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
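A note on the dummy wgrad returned above: once the real gradient has been accumulated into `main_grad`, the op still hands back a placeholder tensor of the weight's shape so that downstream hooks (for example Megatron-LM's DDP/optimizer) see a grad object, while `grad_added_to_main_grad=True` signals that its values are meaningless. The snippet below is a simplified stand-in for `get_dummy_wgrad` written to illustrate the protocol; it is not the library's implementation.

```python
import torch

def make_dummy_wgrad(shape, dtype, zero=False):
    """Simplified stand-in for TE's get_dummy_wgrad (illustration only)."""
    # The placeholder's values are never read unless zero=True is requested.
    if zero:
        return torch.zeros(shape, dtype=dtype)
    return torch.empty(shape, dtype=dtype)

weight = torch.nn.Parameter(torch.randn(16, 8))
weight.grad_added_to_main_grad = False  # Megatron-LM-style flag on the param

# ... real wgrad has already been accumulated into weight.main_grad ...
weight.grad_added_to_main_grad = True
grad_weight = make_dummy_wgrad(
    list(weight.size()),
    weight.dtype,
    zero=getattr(weight, "zero_out_wgrad", False),
)
assert grad_weight.shape == weight.shape  # right shape, placeholder contents
```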
@@ -54,7 +54,8 @@ class Linear(FusedOperation):
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
-        meaningful.
+        meaningful. This is primarily intended to integrate with
+        Megatron-LM.
     """