Unverified commit 5bb771e3, authored by Selvaraj Anandaraj and committed by GitHub
Browse files

Verified TE2.0 with offloading (#1514)



* Verified TE2.0 with offloading
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skipping tests for Ampere and removed child class preparing
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* offloading support for MXFP8 dtype
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changed quantized tensor detection mechanism
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fix mxfp8 offload, lint errors, and var name
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Supported disabling offloading for quantized tensors
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* bug fix
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed bugs
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added support for None in list of Quantized data tensors
Signed-off-by: root <root@prenyx0095.a51.clusters.nvidia.com>

* Hopper backward compatibility cleanup
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Coding style nit
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added guards
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <anandaraj@wisc.edu>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 77fa1e59
...@@ -24,6 +24,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1 ...@@ -24,6 +24,7 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
exit $FAIL exit $FAIL
...@@ -7,8 +7,12 @@ import torch ...@@ -7,8 +7,12 @@ import torch
from contextlib import nullcontext from contextlib import nullcontext
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
SIZE = 4096 # Check if FP8 supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
SIZE = 512
models = { models = {
"linear": te.Linear, "linear": te.Linear,
...@@ -18,40 +22,64 @@ models = { ...@@ -18,40 +22,64 @@ models = {
def _get_input(): def _get_input():
return torch.empty((1, SIZE, SIZE)).cuda() # input size - 1 * 2048 * 2048 * 4b = 16MB return torch.empty((128, SIZE, SIZE)).cuda()
def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
torch.cuda.empty_cache()
model = model_cls(SIZE, SIZE, 1) input_layer = model_cls(SIZE, SIZE)
hidden_layer = model_cls(SIZE, SIZE)
output_layer = model_cls(SIZE, SIZE)
input = _get_input() input = _get_input()
if cpu_offload: if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(enabled=True) offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=2,
model_layers=3,
offload_activations=True,
offload_weights=False,
)
else: else:
offload_context = nullcontext() offload_context = nullcontext()
sync_function = lambda x: x sync_function = lambda x: x
with te.fp8_autocast(enabled=fp8), offload_context: with te.fp8_autocast(enabled=fp8), offload_context:
out = model(input) out = input_layer(input)
out = sync_function(out)
with te.fp8_autocast(enabled=fp8), offload_context:
out = hidden_layer(out)
out = sync_function(out) out = sync_function(out)
input.data = torch.Tensor() # delete data from input with te.fp8_autocast(enabled=fp8), offload_context:
out.data = torch.Tensor() # delete data from out out = output_layer(out)
out = sync_function(out)
max_mem_used = torch.cuda.memory_allocated() / 1024**2
out.sum().backward()
del input_layer
del hidden_layer
del output_layer
del input del input
del out del out
torch.cuda.empty_cache()
allocated_memory_mb = torch.cuda.memory_allocated() / 1024**2
del model
return allocated_memory_mb
torch.cuda.synchronize()
return max_mem_used
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("fp8", [True, False])
@pytest.mark.parametrize("model_key", models.keys()) @pytest.mark.parametrize("model_key", models.keys())
def test_cpu_offload(fp8, model_key) -> None: def test_cpu_offload(fp8, model_key) -> None:
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
model_cls = models[model_key] model_cls = models[model_key]
without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False)
torch.cuda.empty_cache()
with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True) with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True)
assert without_offloading > 30 assert with_offloading < without_offloading
assert with_offloading < 10
...@@ -137,9 +137,7 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): ...@@ -137,9 +137,7 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
super().__init__() super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any: def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push( retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
tensor.data, **self.handler_extra_kwargs
)
return retrieve_identifier return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
...@@ -235,19 +233,15 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): ...@@ -235,19 +233,15 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
@staticmethod @staticmethod
def offload(src_tensor, pin_memory=True): def offload(src_tensor, pin_memory=True):
"""Offload.""" """Offload."""
fp8_offload = isinstance(src_tensor, Float8Tensor)
cpu_backup = torch.empty( cpu_backup = torch.empty(
src_tensor.size(), src_tensor.size(),
dtype=torch.uint8 if fp8_offload else src_tensor.dtype, dtype=src_tensor.dtype,
layout=src_tensor.layout, layout=src_tensor.layout,
device="cpu", device="cpu",
pin_memory=pin_memory, pin_memory=pin_memory,
) )
if fp8_offload:
cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup)
cpu_backup.copy_(src_tensor, non_blocking=pin_memory) cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup) state = (src_tensor.device, cpu_backup)
return state return state
...@@ -311,6 +305,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -311,6 +305,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.num_layers = num_model_group self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors # Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {} self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
# Tracking the number of layers offloaded # Tracking the number of layers offloaded
self.offloaded_group_count = 0 self.offloaded_group_count = 0
# Core data structure that decides the window for offloading # Core data structure that decides the window for offloading
...@@ -341,18 +338,46 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -341,18 +338,46 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
), ),
) )
is_quantized_tensor = callable(getattr(tensor, "prepare_for_saving", None))
if not torch_stray_tensor: if not torch_stray_tensor:
# obtain a unique tensor tag # obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group) tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1 self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state assert tensor_tag not in self.tensor_tag_to_state
self.tensor_tag_to_state[tensor_tag] = tensor if is_quantized_tensor:
tensor_list, _ = tensor.prepare_for_saving()
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( self.tensor_tag_to_state[tensor_tag] = []
tensor self.tensor_tag_to_buf[tensor_tag] = []
self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr(
tensor, "_transpose_invalid"
)
else:
tensor_list = [tensor]
for t in tensor_list:
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(t)
else:
self.tensor_tag_to_state[tensor_tag] = t
if (
self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(t)
): ):
self.tensor_tag_to_buf[tensor_tag] = tensor if is_quantized_tensor:
self.tensor_tag_to_buf[tensor_tag].append(t)
# Need to clear the internal data reference for the quantized tensors
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
else: else:
tensor_tag = (-1, self.torch_tensor_count) tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1 self.torch_tensor_count += 1
...@@ -364,7 +389,14 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -364,7 +389,14 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Tensor pop.""" """Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag) tensor = self.tensor_tag_to_state.pop(tensor_tag)
# Handling the quantized tensor case specially here
if isinstance(tensor, list):
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag)
self.tensor_tag_to_buf.pop(tensor_tag, None) self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward() # the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group. # which invokes bulk_reload_group.
assert not isinstance(tensor, tuple) assert not isinstance(tensor, tuple)
...@@ -377,13 +409,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -377,13 +409,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
group_id, _ = tensor_tag group_id, _ = tensor_tag
if group_id == group_to_offload: if group_id == group_to_offload:
assert not isinstance(state, tuple) assert not isinstance(state, tuple)
tensor_on_device = state
is_quantized_tensor = isinstance(state, list)
if is_quantized_tensor:
tensor_list = state
self.tensor_tag_to_state[tensor_tag] = []
else:
tensor_list = [state]
for tensor_on_device in tensor_list:
# if offload, return the reference to cpu copy # if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device): if self.tensor_need_offloading_checker(tensor_on_device):
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag] = state self.tensor_tag_to_state[tensor_tag] = state
tensor_on_device.data = torch.Tensor() # Force to release memory
def synchronize_on_group_commit_forward(self, current_group): def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward.""" """Synchronize on group commit forward."""
...@@ -433,6 +475,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -433,6 +475,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
if isinstance(state, tuple): if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) recovered_tensor = SynchronizedGroupOffloadHandler.reload(state)
self.tensor_tag_to_state[tensor_label] = recovered_tensor self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(state_tuple)
)
else:
tensor_list.append(state_tuple)
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label)
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label
)
def on_group_commit_backward(self): def on_group_commit_backward(self):
# first decrement the current group. # first decrement the current group.
......
...@@ -509,10 +509,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -509,10 +509,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True self._transpose_invalid = True
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
"""Prepare the tensor base for saving for backward"""
return [self], None
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None): def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
...@@ -285,10 +285,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -285,10 +285,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward"""
return [self], None
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None): def __torch_dispatch__(cls, func, types, args, kwargs=None):
......
...@@ -27,7 +27,7 @@ def prepare_for_saving( ...@@ -27,7 +27,7 @@ def prepare_for_saving(
if tensor is None: if tensor is None:
tensor_list.append(None) tensor_list.append(None)
tensor_objects_list.append(None) tensor_objects_list.append(None)
elif type(tensor) in (torch.Tensor, torch.nn.Parameter): elif isinstance(tensor, torch.Tensor):
tensor_list.append(tensor) tensor_list.append(tensor)
tensor_objects_list.append(None) tensor_objects_list.append(None)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment