Unverified Commit f196d14b authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Activation offloading to CPU's for the Linear, Layernorm Linear and the...


Activation offloading to CPU's for the Linear, Layernorm Linear and the Layernorm MLP modules (#571)

* Added support activation offloading to CPU's
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Moving CPU offloading library to TE
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Restructured code, added switch to choose between weight/activation offloading
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed arg during constructor
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix nit-pick errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Documentation fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix to the code block in docs
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added offloading unit test
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed formatting
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* wgrad fusion fix, minor errors and lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Errors, test, lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* RM test file
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixed stray PyT tensors in LayernormMLP getting offloaded
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed typi
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix offloading for rmsnorm, rm test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix errors
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Float8Tensor compatible offloading
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bacefdbb
......@@ -40,3 +40,5 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.onnx_export
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
......@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext
import torch
import pytest
......@@ -20,6 +21,7 @@ from transformer_engine.pytorch import (
TransformerLayer,
RMSNorm,
LayerNorm,
get_cpu_offload_context,
)
from transformer_engine.common import recipe
......@@ -215,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -223,9 +225,16 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
if skip_wgrad:
_disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
......@@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization, parallel_attention_mlp):
normalization, parallel_attention_mlp,
cpu_offload):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
.cuda()
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
def test_sanity_gpt_126m():
......@@ -512,6 +523,7 @@ def test_sanity_gpt_126m():
activation="gelu",
normalization="LayerNorm",
parallel_attention_mlp=False,
cpu_offload=False,
)
......@@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
.cuda()
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@pytest.mark.parametrize("dtype", param_types)
......@@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
.cuda()
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
@pytest.mark.parametrize("dtype", param_types)
......
......@@ -17,6 +17,7 @@ from .fp8 import fp8_model_init
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
from .cpu_offload import get_cpu_offload_context
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from typing import Any
from contextlib import nullcontext
import torch
from .float8_tensor import Float8Tensor
__all__ = ['get_cpu_offload_context']
CPUOffloadEnabled = False
class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
In this context, the ``on_save_for_backward`` method will be called every time
a tensor is saved for backward (this includes intermediary results saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation).
The ``on_get_saved_tensors`` method will be called when the backward function
of this op attempts to retrieve the saved tensor from context (this includes
:func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the
as input the return value of the ``on_save_for_backward``, and is meant to return
an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of
size, device and element values.
Example:
>>> import torch
>>> from typing import Any
>>>
>>> class DummyHook(CpuOffloadSavedTensorHook):
...
... def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
... logging.info("On save", tensor)
... return (tensor,)
...
... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
... logging.info("On get", saved_state)
... tensor, = saved_state
... return tensor
...
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with DummyHook():
... y = a * b
...
On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)
"""
def __init__(self) -> None:
self.inside_context = False
def __enter__(self):
global CPUOffloadEnabled
CPUOffloadEnabled = True
self.inside_context = True
torch._C._autograd._push_saved_tensors_default_hooks(
self.on_save_for_backward,
self.on_get_saved_tensor
)
def __exit__(self, *args: Any):
global CPUOffloadEnabled
CPUOffloadEnabled = False
self.inside_context = False
torch._C._autograd._pop_saved_tensors_default_hooks()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
"""On save for backward."""
raise NotImplementedError("`on_save_for_backward: Callable[[torch.Tensor], Any]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks")
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
"""On get saved tensor."""
raise NotImplementedError("`on_get_saved_tensors: Callable[[Any], torch.Tensor]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks")
class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
"""Context-manager that offloads/recovers tensors through an offload hander.
The hook just offloads/recovers the tensor object to the handler through `tensor_push`
and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
or prefetching timing is transparent to this hook.
"""
def __init__(self, offload_handler, handler_extra_kwargs={}, debug=False) -> None: # pylint: disable=dangerous-default-value
self.debug = debug
self.offload_handler = offload_handler
self.handler_extra_kwargs = handler_extra_kwargs
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push(
tensor,
**self.handler_extra_kwargs
)
return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
tensor = self.offload_handler.tensor_pop(
saved_state,
**self.handler_extra_kwargs
)
return tensor
class OffloadHandler:
"""A base class for CPU offload-handler."""
def __init__(self) -> None:
pass
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
"""Tensor push."""
raise NotImplementedError("`tensor_push is not implented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_push.")
def tensor_pop(self, tensor_tag: Any, **kwargs):
"""Tensor pop."""
raise NotImplementedError("`tensor_pop is not implented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_pop.")
class GroupCommitFunction(torch.autograd.Function):
"""this is a dummy op with output identical to input.
However, it is necessary for marking a timepoint for offload handler to
accomplish all synchronizations. Implementing it as a function is necessary
because we need to actions in both forward and backward.
"""
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
return tensor
@staticmethod
def backward(ctx, grad_output):
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
group_prefetch_offload_commit = GroupCommitFunction.apply
class SynchronizedGroupOffloadHandler(OffloadHandler):
"""Offload Handler that offloads/reloads in a synchronized way.
The device-to-host and host-to-device copying happen in the same stream
as the computation kernels, thus the copying will block computation.
"""
def __init__(self,
num_offload_group,
tensor_need_offloading_checker=(lambda _: True),
debug=False
) -> None:
super().__init__()
self.num_offload_group = num_offload_group
self.tensor_need_offloading_checker = tensor_need_offloading_checker
self.debug = debug
self.groupid_reset()
def groupid_reset(self):
"""Groupid reset."""
# Data structures to label saved tensors and book-keep their cpu copies.
# Currently, on push, create a new cpu tensor and copies; on pop, copies
# the tensor back to gpu and deletes the cpu tensor.
# These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0)
self.tensor_tag_to_state = {}
def on_group_commit_forward(self):
"""On group commit forward."""
# finishing up with updating current group and tensor count
self.current_group += 1 # increment
self.tensor_count_current_group = 0 # reset
def on_group_commit_backward(self):
"""On group commit backward."""
self.current_group -= 1
assert self.current_group >= 0
@staticmethod
def offload(src_tensor, pin_memory=True):
"""Offload."""
fp8_offload = isinstance(src_tensor, Float8Tensor)
cpu_backup = torch.empty(
src_tensor.size(), dtype=torch.uint8 if fp8_offload else src_tensor.dtype,
layout=src_tensor.layout, device="cpu", pin_memory=pin_memory)
if fp8_offload:
cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup)
cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@staticmethod
def reload(state, non_blocking=None):
"""Reload."""
dev, cpu_backup = state
if non_blocking is None:
non_blocking = cpu_backup.is_pinned()
return cpu_backup.to(dev, non_blocking=non_blocking)
def tensor_push(self, tensor: torch.Tensor, **kwargs):
"""Tensor push."""
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if (self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(tensor)):
state = SynchronizedGroupOffloadHandler.offload(tensor)
self.tensor_tag_to_state[tensor_tag] = state
else:
# will be offloaded together after group commit
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
state = self.tensor_tag_to_state.pop(tensor_tag)
if isinstance(state, tuple):
tensor = SynchronizedGroupOffloadHandler.reload(state)
else:
tensor = state
return tensor
class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Compared to synchronize, this uses more memory because of the buffer but
achieves better performance due to the overlapping. D2h and h2d copying are
completely hidden behind computation if computation time of a layer is longer
than host-device communication time. Bulk offloading with delay and bulk reloading
with prefetch are implemented. """
def __init__(self,
num_offload_group, # must be <= actual number of groups (number of commits)
num_prefetch_group=1,
tensor_need_offloading_checker=(lambda t: True),
debug=False
) -> None:
super().__init__(num_offload_group=num_offload_group,
tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug)
self.num_prefetch_group = num_prefetch_group
# prepare for tensor buffer
self.tensor_id_to_tensor_buf_double_bufs = []
for _ in range(2):
self.tensor_id_to_tensor_buf_double_bufs.append({})
# allocate streams and events for synchronization
self.d2h_stream = torch.cuda.Stream()
self.h2d_stream = torch.cuda.Stream()
self.h2d_finish_events = []
self.compute_stream_bwd_start_events = []
for _ in range(self.num_offload_group):
self.h2d_finish_events.append(torch.cuda.Event())
self.compute_stream_bwd_start_events.append(torch.cuda.Event())
self.d2h_final_event = torch.cuda.Event()
def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag):
"""Get tensor buffer for offloaded tensor."""
group_id, tensor_id = tensor_tag
# obtain ping-pong buffer
id_buf_map = self.tensor_id_to_tensor_buf_double_bufs[(group_id % 2)]
if not tensor_id in id_buf_map:
allocate_new_buf = True
else:
tensor_buf = id_buf_map[tensor_id]
if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement
allocate_new_buf = True
else:
allocate_new_buf = False # in this case, reuse the old buffer
if allocate_new_buf:
# supposed to only execute once
fp8_offload = isinstance(tensor, Float8Tensor)
buffer = torch.empty(
tensor.size(), dtype=torch.uint8 if fp8_offload else tensor.dtype,
layout=tensor.layout, device=tensor.device)
if isinstance(tensor, Float8Tensor):
id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer)
else:
id_buf_map[tensor_id] = buffer
return id_buf_map[tensor_id]
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if (self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(tensor)):
# first copy the tensor to tensorbuf, so that the original tensor will not be deleted
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
tensor_buf.copy_(tensor)
if hasattr(tensor,"weight_offloading"):
tensor_buf.weight_offloading = True
if hasattr(tensor,"activation_offloading"):
tensor_buf.activation_offloading = True
# Here we just save it, and at commit, bulk_offload_group will handle it
self.tensor_tag_to_state[tensor_tag] = tensor_buf
else:
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
# the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group.
assert not isinstance(tensor, tuple)
return tensor
def bulk_offload_group(self, group_to_offload):
"""Bulk offload group."""
with torch.cuda.stream(self.d2h_stream):
for tensor_tag, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_tag
if group_id == group_to_offload:
assert not isinstance(state, tuple)
tensor_on_device = state
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
self.tensor_tag_to_state[tensor_tag] = state
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
# the host should wait for the copying of previous group
# to avoid overwriting buffer
previous_group = current_group - 1
if previous_group < self.num_offload_group:
torch.cuda.synchronize()
# TODO (guyueh): this part is originally designed to reduce the peak memory usage. # pylint: disable=fixme
# however, uncommenting this part will cause illegal access, have not figured out why.
if previous_group + 2 >= self.num_offload_group:
# this buffer is no longer required
self.tensor_id_to_tensor_buf_double_bufs[(previous_group % 2)] = {}
# the copying of this group should wait for the computation stream event
if current_group < self.num_offload_group:
# perform bulk offloading
self.bulk_offload_group(current_group)
if current_group == self.num_offload_group - 1:
self.d2h_stream.record_event(self.d2h_final_event)
def on_group_commit_forward(self):
"""This function will cause host device synchronization"""
# handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group)
# during forward, the next_group_to_fetch always points to the min of
# the last commited group, and the last offloaded group
self.next_group_to_fetch = min(self.current_group, self.num_offload_group -1)
super().on_group_commit_forward()
def bulk_reload_group(self, group_to_reload):
"""Bulk reload group."""
assert group_to_reload < self.num_offload_group
if group_to_reload == self.num_offload_group - 1:
self.h2d_stream.wait_event(self.d2h_final_event)
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload(state)
self.tensor_tag_to_state[tensor_label] = recovered_tensor
def on_group_commit_backward(self):
# first decrement the current group.
# after last commit in forward, the group will +1; in backward it -1.
# Finally it should be decremented to 0.
self.current_group -= 1
assert self.current_group >= 0
# decide the range of group to prefetch
should_prefetch_until_group = self.current_group - self.num_prefetch_group
should_prefetch_until_group = max(should_prefetch_until_group, 0)
# do prefetch
for group_num_to_prefetch in range(
self.next_group_to_fetch, should_prefetch_until_group - 1, -1
):
# record the event in the compute stream, for h2d to wait
torch.cuda.current_stream().record_event(
self.compute_stream_bwd_start_events[group_num_to_prefetch])
# start of h2d should wait for the compute and the d2h
self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch])
#recover tensors (copy back from host)
self.bulk_reload_group(group_num_to_prefetch)
# record an event for the backward of this layer to wait
self.h2d_stream.record_event(self.h2d_finish_events[group_num_to_prefetch])
# always is set to -1 at the end of the backward
self.next_group_to_fetch = min(self.num_offload_group - 1, should_prefetch_until_group - 1)
# wait for the current group
if self.current_group < self.num_offload_group:
torch.cuda.current_stream().wait_event(self.h2d_finish_events[self.current_group])
def get_cpu_offload_context(
enabled: bool = False,
num_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = True):
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
Usage:
.. code-block:: python
cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)
with cpu_offload_context:
te_layer.forward(inp_tensor)
cpu_offload_synchronizer()
Parameters
----------
enabled: bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1
Determines the number of transformer layers
you want to offload activations/weights for.
offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `True`
When set to `True`, offloads the weights for the TE layer.
"""
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor,"activation_offloading")
# This includes the Gradient Accumulation Buffer
def tensor_need_offloading_checker_weights(tensor):
return hasattr(tensor, "weight_offloading")
def tensor_need_offloading_checker_all(tensor): # pylint: disable=unused-argument
return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading"))
if offload_activations and offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_all
elif offload_activations:
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
elif offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_weights
else:
raise ValueError(
"CPU Offloading is enabled while it is not "
"mentioned what to offload (weights/activations)")
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_prefetch_group=1,
tensor_need_offloading_checker=tensor_need_offloading_checker
)
def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor,cpu_offload_handler)
if enabled:
return (
CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
group_prefetch_offload_commit_async,
)
return nullcontext(), group_prefetch_offload_commit_async
......@@ -42,7 +42,6 @@ from ..jit import no_torch_dynamo
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
__all__ = ["LayerNormLinear"]
......@@ -68,6 +67,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
......@@ -239,12 +239,27 @@ class _LayerNormLinear(torch.autograd.Function):
)
if is_grad_enabled:
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8:
weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True
weight.weight_offloading = True
inputmat.activation_offloading = True
if normalization == "LayerNorm":
mu.activation_offloading = True
rsigma.activation_offloading = True
ln_out.activation_offloading = True
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8,
ln_out,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
......@@ -254,6 +269,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
......@@ -298,11 +314,16 @@ class _LayerNormLinear(torch.autograd.Function):
mu,
rsigma,
weight,
main_grad,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
......@@ -582,6 +603,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -992,6 +1014,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
......@@ -1013,6 +1037,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel,
......
......@@ -51,7 +51,6 @@ from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization
__all__ = ["LayerNormMLP"]
......@@ -95,6 +94,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
......@@ -420,6 +420,26 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(gelu_out)
if is_grad_enabled:
if cpu_offloading:
if fuse_wgrad_accumulation:
fc1_weight.main_grad.weight_offloading = True
fc2_weight.main_grad.weight_offloading = True
if fp8:
fc1_weight_t_fp8.weight_offloading = True
fc2_weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True
fc1_weight.weight_offloading = True
fc2_weight.weight_offloading = True
fc1_bias.weight_offloading = True
inputmat.activation_offloading = True
if normalization == "LayerNorm":
mu.activation_offloading = True
rsigma.activation_offloading = True
ln_out.activation_offloading = True
fc1_out.activation_offloading = True
gelu_out.activation_offloading = True
ctx.save_for_backward(
inputmat,
ln_weight,
......@@ -429,8 +449,10 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out,
gelu_out,
fc1_weight,
fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None,
fc2_weight_t_fp8,
fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
......@@ -440,6 +462,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_fc1_bias = use_fc1_bias
ctx.use_fc2_bias = use_fc2_bias
......@@ -492,13 +515,22 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out,
gelu_out,
fc1_weight,
fc1_weight_main_grad,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_main_grad,
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
fc1_weight = Parameter(fc1_weight, False)
fc2_weight = Parameter(fc2_weight, False)
fc1_weight.main_grad = fc1_weight_main_grad
fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8.
if ctx.fp8 and fc1_weight_t_fp8 is None:
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
......@@ -993,6 +1025,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1336,6 +1369,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
fwd_fn = _LayerNormMLP.apply
args = []
......@@ -1362,6 +1397,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel,
......
......@@ -45,7 +45,6 @@ from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
__all__ = ["Linear"]
......@@ -68,6 +67,7 @@ class _Linear(torch.autograd.Function):
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
......@@ -266,12 +266,26 @@ class _Linear(torch.autograd.Function):
saved_inputmat = inputmat
else:
saved_inputmat_t = inputmat_t
if cpu_offloading:
saved_inputmat_t.activation_offloading = True
else:
saved_inputmat = inputmat_no_fp8
if cpu_offloading:
if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True
if fp8:
weight_t_fp8.weight_offloading = True
weight.weight_offloading = True
if saved_inputmat is not None:
saved_inputmat.activation_offloading = True
ctx.save_for_backward(
saved_inputmat,
saved_inputmat_t,
weight,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
......@@ -279,6 +293,7 @@ class _Linear(torch.autograd.Function):
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
......@@ -315,10 +330,15 @@ class _Linear(torch.autograd.Function):
inputmat,
inputmat_t,
weight,
main_grad,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
weight = torch.nn.Parameter(weight, False)
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
......@@ -515,6 +535,7 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -862,6 +883,8 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
......@@ -880,6 +903,7 @@ class Linear(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment