Unverified commit a3b749b1, authored by Selvaraj Anandaraj and committed by GitHub

FSDP grad fusion support (#2191)



* FSDP grad fusion support
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Re-factored grad overwriting usage
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* Update transformer_engine/pytorch/ops/basic/basic_linear.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@nvidia.com>

* Update transformer_engine/pytorch/ops/fused/backward_linear_add.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@nvidia.com>

* Update transformer_engine/pytorch/ops/fused/backward_linear_scale.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@nvidia.com>

* Update transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@nvidia.com>

* Modified API usage, added arg details
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 76e1af33
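
For context, a minimal usage sketch of the behavior this change documents (illustrative only; the layer sizes, dtype, and buffer ownership below are assumptions, not part of this commit):

    import torch
    import transformer_engine.pytorch as te

    # fuse_wgrad_accumulation makes the wgrad GEMM write directly into
    # `weight.main_grad` instead of populating `weight.grad`.
    layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True, device="cuda")

    # Pre-allocated gradient buffer of the correct size (normally owned by the
    # distributed optimizer or the FSDP implementation).
    layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

    # New in this change: with this attribute set on the weight, the wgrad GEMM
    # runs with accumulate=False, so `main_grad` is overwritten rather than
    # accumulated into.
    layer.weight.overwrite_main_grad = True

    out = layer(torch.randn(8, 1024, device="cuda"))
    out.sum().backward()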
@@ -402,7 +402,11 @@ class _GroupedLinear(torch.autograd.Function):
                 use_bias=ctx.use_bias if grad_biases[0] is None else None,
                 bias=biases,
                 use_split_accumulator=wgrad_gemm_use_split_accumulator,
-                accumulate=accumulate_wgrad_into_param_main_grad,
+                accumulate=(
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(weights[0], "overwrite_main_grad", False)
+                    else False
+                ),
             )
             # WGRAD
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
@@ -519,7 +523,9 @@ class GroupedLinear(TransformerEngineBaseModule):
         the weight gradient. When enabled, it is assumed that the weights
         have an additional `main_grad` attribute (used instead of the
         regular `grad`) which is a pre-allocated buffer of the correct
-        size to accumulate gradients in.
+        size to accumulate gradients in. If this argument is enabled and
+        the weight tensor has the attribute `overwrite_main_grad` set to
+        `True`, `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
         when set to `True`, this module will not apply the additive bias itself, but
         instead return the bias value during the forward pass together with the
...
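
The same guard is repeated in every wgrad call in this diff. Written as a standalone sketch (the helper name is invented for illustration), the selection reduces to a simple conjunction:

    def resolve_wgrad_accumulate(weight, accumulate_wgrad_into_param_main_grad):
        # `x if not flag else False` is equivalent to `x and not flag`: a weight
        # tagged with overwrite_main_grad=True always forces a non-accumulating GEMM.
        return accumulate_wgrad_into_param_main_grad and not getattr(
            weight, "overwrite_main_grad", False
        )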
@@ -849,7 +849,11 @@ class _LayerNormLinear(torch.autograd.Function):
                 main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
             ),
             "quantization_params": ctx.grad_weight_quantizer,
-            "accumulate": accumulate_wgrad_into_param_main_grad,
+            "accumulate": (
+                accumulate_wgrad_into_param_main_grad
+                if not getattr(weight, "overwrite_main_grad", False)
+                else False
+            ),
             "layout": "NT",
             "out": main_grad if ctx.fuse_wgrad_accumulation else None,
             "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
@@ -1125,7 +1129,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
         the weight gradient. When enabled, it is assumed that the weights
         have an additional `main_grad` attribute (used instead of the
         regular `grad`) which is a pre-allocated buffer of the correct
-        size to accumulate gradients in.
+        size to accumulate gradients in. If this argument is enabled and
+        the weight tensor has the attribute `overwrite_main_grad` set to
+        `True`, `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
         when set to `True`, this module will not apply the additive bias itself, but
         instead return the bias value during the forward pass together with the
...
@@ -948,7 +948,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 else ctx.activation_dtype
             ),
             "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-            "accumulate": accumulate_wgrad_into_param_main_grad,
+            "accumulate": (
+                accumulate_wgrad_into_param_main_grad
+                if not getattr(fc1_weight, "overwrite_main_grad", False)
+                else False
+            ),
             "layout": "NT",
             "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
@@ -1189,7 +1193,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 else ctx.activation_dtype
             ),
             "quantization_params": ctx.fc1_grad_weight_quantizer,
-            "accumulate": accumulate_wgrad_into_param_main_grad,
+            "accumulate": (
+                accumulate_wgrad_into_param_main_grad
+                if not getattr(fc2_weight, "overwrite_main_grad", False)
+                else False
+            ),
             "layout": "NT",
             "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
@@ -1484,7 +1492,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
         the weight gradient. When enabled, it is assumed that the weights
         have an additional `main_grad` attribute (used instead of the
         regular `grad`) which is a pre-allocated buffer of the correct
-        size to accumulate gradients in.
+        size to accumulate gradients in. If this argument is enabled and
+        the weight tensor has the attribute `overwrite_main_grad` set to
+        `True`, `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
         when set to `True`, this module will not apply the additive bias for FC2, but
         instead return the bias value during the forward pass together with the
...
@@ -843,7 +843,11 @@ class _Linear(torch.autograd.Function):
                 main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
             ),
             "quantization_params": ctx.grad_weight_quantizer,
-            "accumulate": accumulate_wgrad_into_param_main_grad,
+            "accumulate": (
+                accumulate_wgrad_into_param_main_grad
+                if not getattr(weight, "overwrite_main_grad", False)
+                else False
+            ),
             "layout": "NT",
             "out": main_grad if ctx.fuse_wgrad_accumulation else None,
             "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
@@ -1061,7 +1065,9 @@ class Linear(TransformerEngineBaseModule):
         the weight gradient. When enabled, it is assumed that the weights
         have an additional `main_grad` attribute (used instead of the
         regular `grad`) which is a pre-allocated buffer of the correct
-        size to accumulate gradients in.
+        size to accumulate gradients in. If this argument is enabled and
+        the weight tensor has the attribute `overwrite_main_grad` set to
+        `True`, `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
         when set to `True`, this module will not apply the additive bias itself, but
         instead return the bias value during the forward pass together with the
...
@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
         meaningful. This is primarily intended to integrate with
-        Megatron-LM.
+        Megatron-LM. If this argument is enabled and the weight tensor has
+        the attribute `overwrite_main_grad` set to `True`, `main_grad` is
+        overwritten instead of accumulated into.
     userbuffers_options, dict, optional
         Options for overlapping tensor-parallel communication with
         compute using Userbuffers. This feature is highly
@@ -1019,6 +1021,7 @@ class BasicLinear(BasicOperation):
         weight_param = self.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+        accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
...
@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation):
         weight_param = linear_op.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+        accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
...
@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation):
         weight_param = linear_op.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+        accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
...
@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         weight_param = linear_op.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+        accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
...
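
The fused-ops hunks above read three attributes off the weight parameter: `__fsdp_param__`, `get_main_grad()`, and the new `overwrite_main_grad`. A hypothetical shim (the helper name and its arguments are invented; only the attribute names come from the diff) sketching the contract an FSDP integration is expected to satisfy:

    import torch

    def attach_fsdp_grad_fusion(param, grad_buffer, overwrite=True):
        """Illustrative only: satisfy the attribute contract read by the fused ops."""
        # Presence of this attribute makes the op refresh `main_grad` lazily via
        # `get_main_grad()`, e.g. when FSDP reshards or reallocates the buffer.
        param.__fsdp_param__ = True
        param.get_main_grad = lambda: grad_buffer
        # New attribute from this commit: request overwrite rather than accumulation.
        param.overwrite_main_grad = overwrite

    # Example: a weight plus a separately owned FP32 gradient buffer.
    weight = torch.nn.Parameter(torch.empty(1024, 1024, device="cuda"))
    buffer = torch.zeros_like(weight, dtype=torch.float32)
    attach_fsdp_grad_fusion(weight, buffer)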