Unverified Commit a3b749b1 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

FSDP grad fusion support (#2191)



* FSDP grad fusion support
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Re-factored grad overwriting usage
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* Update transformer_engine/pytorch/ops/basic/basic_linear.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@nvidia.com>

* Update transformer_engine/pytorch/ops/fused/backward_linear_add.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@nvidia.com>

* Update transformer_engine/pytorch/ops/fused/backward_linear_scale.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@nvidia.com>

* Update transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@nvidia.com>

* Modified API usage, added arg details
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 76e1af33
......@@ -402,7 +402,11 @@ class _GroupedLinear(torch.autograd.Function):
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
accumulate=(
accumulate_wgrad_into_param_main_grad
if not getattr(weights[0], "overwrite_main_grad", False)
else False
),
)
# WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
......@@ -519,7 +523,9 @@ class GroupedLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
......
......@@ -849,7 +849,11 @@ class _LayerNormLinear(torch.autograd.Function):
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
......@@ -1125,7 +1129,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
......
......@@ -948,7 +948,11 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype
),
"quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(fc1_weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
......@@ -1189,7 +1193,11 @@ class _LayerNormMLP(torch.autograd.Function):
else ctx.activation_dtype
),
"quantization_params": ctx.fc1_grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(fc2_weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
......@@ -1484,7 +1492,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
......
......@@ -843,7 +843,11 @@ class _Linear(torch.autograd.Function):
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
......@@ -1061,7 +1065,9 @@ class Linear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
......
......@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful. This is primarily intented to integrate with
Megatron-LM.
Megatron-LM. This argument along with weight tensor having
attribute 'overwrite_main_grad' set to True will overwrite
`main_grad` instead of accumulating.
userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
......@@ -1019,6 +1021,7 @@ class BasicLinear(BasicOperation):
weight_param = self.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
......@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation):
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment