Unverified commit f196d14b authored by Selvaraj Anandaraj, committed by GitHub

Activation offloading to CPU for the Linear, LayerNormLinear and LayerNormMLP modules (#571)

* Added support for activation offloading to CPU
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Moving CPU offloading library to TE
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Restructured code, added switch to choose between weight/activation offloading
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed arg during constructor
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix nit-pick errors
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Documentation fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix to the code block in docs
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added offloading unit test
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed formatting
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* wgrad fusion fix, minor errors and lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Errors, test, lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* RM test file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixed stray PyT tensors in LayernormMLP getting offloaded
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix offloading for rmsnorm, rm test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Float8Tensor compatible offloading
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bacefdbb
@@ -40,3 +40,5 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.checkpoint
 .. autoapifunction:: transformer_engine.pytorch.onnx_export
+.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
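The API addition above is exercised by the test changes below. As a minimal usage sketch distilled from those tests (the layer type and sizes here are arbitrary choices, not from this diff):

```python
import torch
import transformer_engine.pytorch as te

# get_cpu_offload_context returns a context manager plus a function that
# synchronizes offloaded activations; the tests below fall back to
# contextlib.nullcontext and an identity function when offloading is off.
offload_context, sync_function = te.get_cpu_offload_context(enabled=True)

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Run the forward pass under the context so tagged activations can be
# moved to CPU, then synchronize the output before using it.
with offload_context:
    out = layer(inp)
out = sync_function(out)

out.sum().backward()
```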
@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import Optional
+from contextlib import nullcontext
 import torch
 import pytest
@@ -20,6 +21,7 @@ from transformer_engine.pytorch import (
     TransformerLayer,
     RMSNorm,
     LayerNorm,
+    get_cpu_offload_context,
 )
 from transformer_engine.common import recipe
@@ -215,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
         assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


-def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
+def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
     te_inp_hidden_states = torch.randn(
         config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -223,9 +225,16 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
     if skip_wgrad:
         _disable_wgrads(block)

+    if cpu_offload:
+        offload_context, sync_function = get_cpu_offload_context(enabled=True)
+    else:
+        offload_context = nullcontext()
+        sync_function = lambda x: x
+
     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
         te_out = block(te_inp_hidden_states)
+    te_out = sync_function(te_out)
     loss = te_out.sum()
     loss.backward()
     torch.cuda.synchronize()
@@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
+@pytest.mark.parametrize("cpu_offload", all_boolean)
 def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
                     zero_centered_gamma, bias, activation,
-                    normalization, parallel_attention_mlp):
+                    normalization, parallel_attention_mlp,
+                    cpu_offload):
     config = model_configs[model]

     if fp8_recipe is not None:
@@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
         .cuda()
     )

-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


 def test_sanity_gpt_126m():
@@ -512,6 +523,7 @@ def test_sanity_gpt_126m():
         activation="gelu",
         normalization="LayerNorm",
         parallel_attention_mlp=False,
+        cpu_offload=False,
     )
@@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
         .cuda()
     )

-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


 @pytest.mark.parametrize("dtype", param_types)
@@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
         .cuda()
     )

-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


 @pytest.mark.parametrize("dtype", param_types)
@@ -17,6 +17,7 @@ from .fp8 import fp8_model_init
 from .export import onnx_export
 from .distributed import checkpoint
 from .distributed import CudaRNGStatesTracker
+from .cpu_offload import get_cpu_offload_context
 # Register custom op symbolic ONNX functions
 from .te_onnx_extensions import (
     onnx_cast_to_fp8,
This diff is collapsed.
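The collapsed diff is the new `cpu_offload` module itself (imported as `.cpu_offload` in `__init__.py` above). Purely as an illustrative sketch of the general mechanism, and not the module's actual code, a context of this shape can be built on `torch.autograd.graph.saved_tensors_hooks`, offloading only the tensors that the module diffs below tag with `activation_offloading`:

```python
import torch
from contextlib import contextmanager

@contextmanager
def cpu_offload_context():  # hypothetical stand-in, not TE's implementation
    # pack() runs whenever autograd saves a tensor for backward. Only
    # tensors tagged by the TE modules are copied to pinned CPU memory;
    # everything else stays on the GPU as usual.
    def pack(tensor):
        if getattr(tensor, "activation_offloading", False):
            cpu_copy = torch.empty(
                tensor.size(), dtype=tensor.dtype, device="cpu", pin_memory=True
            )
            cpu_copy.copy_(tensor, non_blocking=True)
            return ("offloaded", cpu_copy, tensor.device)
        return ("resident", tensor, None)

    # unpack() runs when backward needs the saved tensor again.
    def unpack(packed):
        state, tensor, device = packed
        if state == "offloaded":
            return tensor.to(device, non_blocking=True)
        return tensor

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        yield
```

The real module additionally honors the `weight_offloading` tags and returns the `sync_function` used in the tests; none of that is reproduced in this sketch.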
@@ -42,7 +42,6 @@ from ..jit import no_torch_dynamo
 from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor

-
 __all__ = ["LayerNormLinear"]
@@ -68,6 +67,7 @@ class _LayerNormLinear(torch.autograd.Function):
         fp8_calibration: bool,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
+        cpu_offloading: bool,
         tp_group: Union[dist_group_type, None],
         tp_size: int,
         sequence_parallel: bool,
@@ -239,12 +239,27 @@
         )

         if is_grad_enabled:
+            if cpu_offloading:
+                if fuse_wgrad_accumulation:
+                    weight.main_grad.weight_offloading = True
+                if fp8:
+                    weight_t_fp8.weight_offloading = True
+                ln_weight.weight_offloading = True
+                weight.weight_offloading = True
+
+                inputmat.activation_offloading = True
+                if normalization == "LayerNorm":
+                    mu.activation_offloading = True
+                    rsigma.activation_offloading = True
+                ln_out.activation_offloading = True
+
             ctx.save_for_backward(
                 inputmat,
                 ln_weight,
                 mu,
                 rsigma,
                 weight,
+                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
                 weight_t_fp8,
                 ln_out,
                 fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
@@ -254,6 +269,7 @@
             ctx.fp8 = fp8
             ctx.fp8_meta = fp8_meta
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
+            ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
             ctx.use_bias = use_bias
             ctx.sequence_parallel = sequence_parallel
@@ -298,11 +314,16 @@
             mu,
             rsigma,
             weight,
+            main_grad,
             weight_t_fp8,
             ln_out,
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+            weight = torch.nn.Parameter(weight, False)
+            weight.main_grad = main_grad
+
         # Primary weights are in FP8.
         if ctx.fp8 and weight_t_fp8 is None:
             weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
@@ -582,6 +603,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -992,6 +1014,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 is_first_microbatch
             )

+            from ..cpu_offload import CPUOffloadEnabled
+
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
                 args = []
@@ -1013,6 +1037,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
+                CPUOffloadEnabled,
                 self.tp_group,
                 self.tp_size,
                 self.sequence_parallel,
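One detail worth noting in the hunks above: `save_for_backward` keeps tensor data, and Python-level attributes such as `main_grad` attached to the weight are not guaranteed to survive the save/offload round trip. The backward pass therefore receives `main_grad` as a separately saved tensor and reattaches it to a rewrapped `Parameter`. A small standalone illustration of the pattern (not TE code):

```python
import torch

weight = torch.nn.Parameter(torch.randn(4, 4))
weight.main_grad = torch.zeros(4, 4)  # fused wgrad accumulation buffer

# What comes back from a save/offload round trip is tensor data,
# without instance attributes such as .main_grad.
saved = weight.detach().clone()
assert not hasattr(saved, "main_grad")

# Mirror of the backward-pass fixup in the diff above: rewrap as a
# non-trainable Parameter and reattach the separately saved buffer.
restored = torch.nn.Parameter(saved, requires_grad=False)
restored.main_grad = weight.main_grad
assert torch.equal(restored.main_grad, weight.main_grad)
```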
@@ -51,7 +51,6 @@ from ..jit import no_torch_dynamo
 from ..float8_tensor import Float8Tensor
 from ._common import _apply_normalization

-
 __all__ = ["LayerNormMLP"]
@@ -95,6 +94,7 @@ class _LayerNormMLP(torch.autograd.Function):
         fp8_calibration: bool,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
+        cpu_offloading: bool,
         tp_group: Union[dist_group_type, None],
         tp_size: int,
         sequence_parallel: bool,
@@ -420,6 +420,26 @@
             clear_tensor_data(gelu_out)

         if is_grad_enabled:
+            if cpu_offloading:
+                if fuse_wgrad_accumulation:
+                    fc1_weight.main_grad.weight_offloading = True
+                    fc2_weight.main_grad.weight_offloading = True
+                if fp8:
+                    fc1_weight_t_fp8.weight_offloading = True
+                    fc2_weight_t_fp8.weight_offloading = True
+                ln_weight.weight_offloading = True
+                fc1_weight.weight_offloading = True
+                fc2_weight.weight_offloading = True
+                fc1_bias.weight_offloading = True
+
+                inputmat.activation_offloading = True
+                if normalization == "LayerNorm":
+                    mu.activation_offloading = True
+                    rsigma.activation_offloading = True
+                ln_out.activation_offloading = True
+                fc1_out.activation_offloading = True
+                gelu_out.activation_offloading = True
+
             ctx.save_for_backward(
                 inputmat,
                 ln_weight,
@@ -429,8 +449,10 @@
                 fc1_out,
                 gelu_out,
                 fc1_weight,
+                fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
                 fc1_weight_t_fp8,
                 fc2_weight,
+                fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
                 fc2_weight_t_fp8,
                 fc1_bias,
                 fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
@@ -440,6 +462,7 @@
             ctx.fp8 = fp8
             ctx.fp8_meta = fp8_meta
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
+            ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
             ctx.use_fc1_bias = use_fc1_bias
             ctx.use_fc2_bias = use_fc2_bias
@@ -492,13 +515,22 @@
             fc1_out,
             gelu_out,
             fc1_weight,
+            fc1_weight_main_grad,
             fc1_weight_t_fp8,
             fc2_weight,
+            fc2_weight_main_grad,
             fc2_weight_t_fp8,
             fc1_bias,
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+            fc1_weight = Parameter(fc1_weight, False)
+            fc2_weight = Parameter(fc2_weight, False)
+
+            fc1_weight.main_grad = fc1_weight_main_grad
+            fc2_weight.main_grad = fc2_weight_main_grad
+
         # Primary weights are in FP8.
         if ctx.fp8 and fc1_weight_t_fp8 is None:
             fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
@@ -993,6 +1025,7 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1336,6 +1369,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 is_first_microbatch
             )

+            from ..cpu_offload import CPUOffloadEnabled
+
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormMLP.apply
                 args = []
@@ -1362,6 +1397,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
+                CPUOffloadEnabled,
                 self.tp_group,
                 self.tp_size,
                 self.sequence_parallel,
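The commit history mentions a switch between weight and activation offloading, and the hunks above tag the two groups with distinct attributes (`weight_offloading` for parameters, their FP8 transposes, and gradient buffers; `activation_offloading` for saved activations). A hypothetical pack-hook filter honoring such a switch might look like this (the attribute names are the real ones from the diff; the function itself is illustrative):

```python
def should_offload(tensor, offload_weights: bool, offload_activations: bool) -> bool:
    """Decide inside a pack hook whether a saved tensor should move to CPU."""
    if offload_weights and getattr(tensor, "weight_offloading", False):
        return True
    if offload_activations and getattr(tensor, "activation_offloading", False):
        return True
    return False
```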
@@ -45,7 +45,6 @@ from ..jit import no_torch_dynamo
 from ..float8_tensor import Float8Tensor

-
 __all__ = ["Linear"]
@@ -68,6 +67,7 @@ class _Linear(torch.autograd.Function):
         fp8_calibration: bool,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
+        cpu_offloading: bool,
         tp_group: Union[dist_group_type, None],
         tp_size: int,
         sequence_parallel: bool,
@@ -266,12 +266,26 @@
                     saved_inputmat = inputmat
                 else:
                     saved_inputmat_t = inputmat_t
+                    if cpu_offloading:
+                        saved_inputmat_t.activation_offloading = True
             else:
                 saved_inputmat = inputmat_no_fp8
+
+            if cpu_offloading:
+                if fuse_wgrad_accumulation:
+                    weight.main_grad.weight_offloading = True
+                if fp8:
+                    weight_t_fp8.weight_offloading = True
+                weight.weight_offloading = True
+
+                if saved_inputmat is not None:
+                    saved_inputmat.activation_offloading = True
+
             ctx.save_for_backward(
                 saved_inputmat,
                 saved_inputmat_t,
                 weight,
+                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
                 weight_t_fp8 if fp8 else None,
                 fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
             )
@@ -279,6 +293,7 @@
             ctx.fp8 = fp8
             ctx.fp8_meta = fp8_meta
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
+            ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
             ctx.use_bias = use_bias
             ctx.sequence_parallel = sequence_parallel
@@ -315,10 +330,15 @@
             inputmat,
             inputmat_t,
             weight,
+            main_grad,
             weight_t_fp8,
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+            weight = torch.nn.Parameter(weight, False)
+            weight.main_grad = main_grad
+
         # Primary weights are in FP8.
         if ctx.fp8 and weight_t_fp8 is None:
             weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
@@ -515,6 +535,7 @@
             None,
             None,
             None,
+            None,
         )
@@ -862,6 +883,8 @@ class Linear(TransformerEngineBaseModule):
                 is_first_microbatch
             )

+            from ..cpu_offload import CPUOffloadEnabled
+
             if torch.is_grad_enabled():
                 linear_fn = _Linear.apply
                 args = []
@@ -880,6 +903,7 @@ class Linear(TransformerEngineBaseModule):
                 self.fp8_calibration,
                 self.fp8_meta,
                 self.fuse_wgrad_accumulation,
+                CPUOffloadEnabled,
                 self.tp_group,
                 self.tp_size,
                 self.sequence_parallel,
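Note also the deliberately late `from ..cpu_offload import CPUOffloadEnabled` inside each module's `forward` above: the flag is read at call time rather than at construction, so the same module instance can run with or without offloading from one call to the next. A usage sketch consistent with the test changes earlier in this diff (shapes arbitrary):

```python
import torch
import transformer_engine.pytorch as te

model = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096).cuda()
offload_context, sync_function = te.get_cpu_offload_context(enabled=True)

for _ in range(4):  # e.g. microbatches within one optimizer step
    x = torch.randn(128, 8, 1024, device="cuda")  # [seq, batch, hidden]
    with offload_context:
        y = model(x)
    y = sync_function(y)  # synchronize before the output is consumed
    y.sum().backward()
```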