Unverified commit f196d14b, authored by Selvaraj Anandaraj and committed by GitHub
Browse files

Activation offloading to CPUs for the Linear, LayerNormLinear and the...


Activation offloading to CPUs for the Linear, LayerNormLinear and LayerNormMLP modules (#571)

* Added support for activation offloading to CPUs
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Moving CPU offloading library to TE
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Restructured code, added switch to choose between weight/activation offloading
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed arg during constructor
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix nit-pick errors
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Documentation fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix to the code block in docs
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added offloading unit test
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed formatting
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* wgrad fusion fix, minor errors and lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Errors, test, lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* RM test file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixed stray PyT tensors in LayernormMLP getting offloaded
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix offloading for rmsnorm, rm test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Float8Tensor compatible offloading
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bacefdbb
...@@ -40,3 +40,5 @@ pyTorch ...@@ -40,3 +40,5 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.checkpoint .. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.onnx_export .. autoapifunction:: transformer_engine.pytorch.onnx_export
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from contextlib import nullcontext
import torch import torch
import pytest import pytest
...@@ -20,6 +21,7 @@ from transformer_engine.pytorch import ( ...@@ -20,6 +21,7 @@ from transformer_engine.pytorch import (
TransformerLayer, TransformerLayer,
RMSNorm, RMSNorm,
LayerNorm, LayerNorm,
get_cpu_offload_context,
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
...@@ -215,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci ...@@ -215,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated." assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda() ).cuda()
...@@ -223,9 +225,16 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -223,9 +225,16 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
te_out = block(te_inp_hidden_states) te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum() loss = te_out.sum()
loss.backward() loss.backward()
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, ...@@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation, zero_centered_gamma, bias, activation,
normalization, parallel_attention_mlp): normalization, parallel_attention_mlp,
cpu_offload):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, ...@@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
.cuda() .cuda()
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
def test_sanity_gpt_126m(): def test_sanity_gpt_126m():
...@@ -512,6 +523,7 @@ def test_sanity_gpt_126m(): ...@@ -512,6 +523,7 @@ def test_sanity_gpt_126m():
activation="gelu", activation="gelu",
normalization="LayerNorm", normalization="LayerNorm",
parallel_attention_mlp=False, parallel_attention_mlp=False,
cpu_offload=False,
) )
...@@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
.cuda() .cuda()
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
...@@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
.cuda() .cuda()
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
......
...@@ -17,6 +17,7 @@ from .fp8 import fp8_model_init ...@@ -17,6 +17,7 @@ from .fp8 import fp8_model_init
from .export import onnx_export from .export import onnx_export
from .distributed import checkpoint from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker from .distributed import CudaRNGStatesTracker
from .cpu_offload import get_cpu_offload_context
# Register custom op symbolic ONNX functions # Register custom op symbolic ONNX functions
from .te_onnx_extensions import ( from .te_onnx_extensions import (
onnx_cast_to_fp8, onnx_cast_to_fp8,
......
This diff is collapsed.
...@@ -42,7 +42,6 @@ from ..jit import no_torch_dynamo ...@@ -42,7 +42,6 @@ from ..jit import no_torch_dynamo
from ._common import _apply_normalization, _noop_cat from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
...@@ -68,6 +67,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -68,6 +67,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
sequence_parallel: bool, sequence_parallel: bool,
...@@ -239,12 +239,27 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -239,12 +239,27 @@ class _LayerNormLinear(torch.autograd.Function):
) )
if is_grad_enabled: if is_grad_enabled:
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8:
weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True
weight.weight_offloading = True
inputmat.activation_offloading = True
if normalization == "LayerNorm":
mu.activation_offloading = True
rsigma.activation_offloading = True
ln_out.activation_offloading = True
ctx.save_for_backward( ctx.save_for_backward(
inputmat, inputmat,
ln_weight, ln_weight,
mu, mu,
rsigma, rsigma,
weight, weight,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8, weight_t_fp8,
ln_out, ln_out,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
...@@ -254,6 +269,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -254,6 +269,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
...@@ -298,11 +314,16 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -298,11 +314,16 @@ class _LayerNormLinear(torch.autograd.Function):
mu, mu,
rsigma, rsigma,
weight, weight,
main_grad,
weight_t_fp8, weight_t_fp8,
ln_out, ln_out,
fwd_scale_inverses, fwd_scale_inverses,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch) weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
...@@ -582,6 +603,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -582,6 +603,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -992,6 +1014,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -992,6 +1014,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch is_first_microbatch
) )
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
args = [] args = []
...@@ -1013,6 +1037,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1013,6 +1037,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.sequence_parallel, self.sequence_parallel,
......
...@@ -51,7 +51,6 @@ from ..jit import no_torch_dynamo ...@@ -51,7 +51,6 @@ from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization from ._common import _apply_normalization
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -95,6 +94,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -95,6 +94,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
sequence_parallel: bool, sequence_parallel: bool,
...@@ -420,6 +420,26 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -420,6 +420,26 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(gelu_out) clear_tensor_data(gelu_out)
if is_grad_enabled: if is_grad_enabled:
if cpu_offloading:
if fuse_wgrad_accumulation:
fc1_weight.main_grad.weight_offloading = True
fc2_weight.main_grad.weight_offloading = True
if fp8:
fc1_weight_t_fp8.weight_offloading = True
fc2_weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True
fc1_weight.weight_offloading = True
fc2_weight.weight_offloading = True
fc1_bias.weight_offloading = True
inputmat.activation_offloading = True
if normalization == "LayerNorm":
mu.activation_offloading = True
rsigma.activation_offloading = True
ln_out.activation_offloading = True
fc1_out.activation_offloading = True
gelu_out.activation_offloading = True
ctx.save_for_backward( ctx.save_for_backward(
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -429,8 +449,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -429,8 +449,10 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out, fc1_out,
gelu_out, gelu_out,
fc1_weight, fc1_weight,
fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
fc1_weight_t_fp8, fc1_weight_t_fp8,
fc2_weight, fc2_weight,
fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
...@@ -440,6 +462,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -440,6 +462,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_fc1_bias = use_fc1_bias ctx.use_fc1_bias = use_fc1_bias
ctx.use_fc2_bias = use_fc2_bias ctx.use_fc2_bias = use_fc2_bias
...@@ -492,13 +515,22 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -492,13 +515,22 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out, fc1_out,
gelu_out, gelu_out,
fc1_weight, fc1_weight,
fc1_weight_main_grad,
fc1_weight_t_fp8, fc1_weight_t_fp8,
fc2_weight, fc2_weight,
fc2_weight_main_grad,
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fwd_scale_inverses, fwd_scale_inverses,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
fc1_weight = Parameter(fc1_weight, False)
fc2_weight = Parameter(fc2_weight, False)
fc1_weight.main_grad = fc1_weight_main_grad
fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and fc1_weight_t_fp8 is None: if ctx.fp8 and fc1_weight_t_fp8 is None:
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch) fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
...@@ -993,6 +1025,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -993,6 +1025,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1336,6 +1369,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1336,6 +1369,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch is_first_microbatch
) )
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormMLP.apply fwd_fn = _LayerNormMLP.apply
args = [] args = []
...@@ -1362,6 +1397,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1362,6 +1397,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.sequence_parallel, self.sequence_parallel,
......
...@@ -45,7 +45,6 @@ from ..jit import no_torch_dynamo ...@@ -45,7 +45,6 @@ from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -68,6 +67,7 @@ class _Linear(torch.autograd.Function): ...@@ -68,6 +67,7 @@ class _Linear(torch.autograd.Function):
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
tp_size: int, tp_size: int,
sequence_parallel: bool, sequence_parallel: bool,
...@@ -266,12 +266,26 @@ class _Linear(torch.autograd.Function): ...@@ -266,12 +266,26 @@ class _Linear(torch.autograd.Function):
saved_inputmat = inputmat saved_inputmat = inputmat
else: else:
saved_inputmat_t = inputmat_t saved_inputmat_t = inputmat_t
if cpu_offloading:
saved_inputmat_t.activation_offloading = True
else: else:
saved_inputmat = inputmat_no_fp8 saved_inputmat = inputmat_no_fp8
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8:
weight_t_fp8.weight_offloading = True
weight.weight_offloading = True
if saved_inputmat is not None:
saved_inputmat.activation_offloading = True
ctx.save_for_backward( ctx.save_for_backward(
saved_inputmat, saved_inputmat,
saved_inputmat_t, saved_inputmat_t,
weight, weight,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None, weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
) )
...@@ -279,6 +293,7 @@ class _Linear(torch.autograd.Function): ...@@ -279,6 +293,7 @@ class _Linear(torch.autograd.Function):
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
...@@ -315,10 +330,15 @@ class _Linear(torch.autograd.Function): ...@@ -315,10 +330,15 @@ class _Linear(torch.autograd.Function):
inputmat, inputmat,
inputmat_t, inputmat_t,
weight, weight,
main_grad,
weight_t_fp8, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch) weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
...@@ -515,6 +535,7 @@ class _Linear(torch.autograd.Function): ...@@ -515,6 +535,7 @@ class _Linear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -862,6 +883,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -862,6 +883,8 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch is_first_microbatch
) )
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
linear_fn = _Linear.apply linear_fn = _Linear.apply
args = [] args = []
...@@ -880,6 +903,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -880,6 +903,7 @@ class Linear(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.sequence_parallel, self.sequence_parallel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment