"tests/pytorch/attention/test_attention.py" did not exist on "59c0f096b61e43d2890cc64bfa00a52410f162e1"
Unverified Commit 4d4f1edb authored by Selvaraj Anandaraj, committed by GitHub

CPU reload double buffer (#1695)



* Added double buffering support initial commit
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed bugs
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* Make only one double buffer creation
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* Fixed bug
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* Fixed flag setting
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Merge conflict
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <anandaraj@wisc.edu>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5d01ef21
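At a glance, the diff below threads one idea through the stack: keep two persistent GPU reload buffers and alternate between them. `SynchronizedGroupOffloadHandler.reload()` gains an optional `copy_buffer` argument so a reload can copy into a preallocated tensor instead of allocating a new one; `AsyncDoubleBufferGroupOffloadHandler` builds the two buffer sets during the first forward pass; `bulk_reload_group()` picks a buffer set by group parity; reloaded tensors are tagged `do_not_clear` so `clear_tensor_data()` leaves the buffer-owned storage alone; and `get_cpu_offload_context()` exposes the feature as a `double_buffering` flag.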
@@ -97,6 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload
     max_mem_used = torch.cuda.memory_allocated() / (1024**2)
+    torch.cuda.synchronize()
+
     tensor.sum().backward()
 
     return max_mem_used
@@ -115,6 +117,9 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
     the difference being the size of the FP8 cache that is not offloaded to the CPU.
     We also expect this memory consumption to be smaller than in scenario (1).
     """
+    import gc
+
+    gc.collect()
 
     model_cls = model_types[model_key]
     models_list = [model_cls() for _ in range(NUM_LAYERS)]
......
@@ -253,13 +253,21 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
         return state
 
     @staticmethod
-    def reload(state, non_blocking=None):
+    def reload(state, non_blocking=None, copy_buffer=None):
         """Reload."""
         dev, cpu_backup = state
         if non_blocking is None:
             non_blocking = cpu_backup.is_pinned()
+
+        if copy_buffer is None:
+            return cpu_backup.to(dev, non_blocking=non_blocking)
+
+        assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!"
+
+        copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
+
+        return copy_buffer
 
     def tensor_push(self, tensor: torch.Tensor, **kwargs):
         """Tensor push."""
         # obtain a unique tensor tag
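The `copy_buffer` path is what the double buffer relies on: instead of `Tensor.to()` allocating a fresh device tensor on every reload, the pinned host copy is written into a caller-owned buffer with `copy_()`, so the same device memory is reused across reloads. A minimal standalone sketch of the two paths (illustrative names, assumes a CUDA device; this is not TE's handler code):

    import torch

    cpu_backup = torch.randn(1024, 1024).pin_memory()  # offloaded host copy

    # Path 1 (copy_buffer is None): allocates a new device tensor on every reload.
    fresh = cpu_backup.to("cuda", non_blocking=True)

    # Path 2 (copy_buffer given): reuse one preallocated device tensor across reloads.
    copy_buffer = torch.empty_like(cpu_backup, device="cuda")
    copy_buffer.copy_(cpu_backup, non_blocking=True)  # async H2D copy into the buffer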
@@ -300,6 +308,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         num_offload_group,  # must be <= actual number of groups (number of commits)
         num_model_group,
         tensor_need_offloading_checker=(lambda t: True),
+        double_buffering=False,
         debug=False,
     ) -> None:
         super().__init__(
@@ -320,6 +329,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         # Core data structure that decides the window for offloading
         self.layer_window_map = {}
 
+        # Data structures for double buffered reloading
+        self.double_buffering = double_buffering
+        self.reload_double_buffer = [[], []]
+        self.double_buffer_created = False
+
         # Logic to make offloading load balance across computation
         # for optimal CPU/GPU interconnect usage
         constant = 0
@@ -413,8 +427,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
             tensor = self.fp8_tensor_object_map.pop(tensor_tag)
 
-        self.tensor_tag_to_buf.pop(tensor_tag, None)
+        if self.double_buffering:
+            tensor.do_not_clear = True
+        self.tensor_tag_to_buf.pop(tensor_tag, None)
 
         # the tensor should have been copied back in on_group_commit_backward()
         # which invokes bulk_reload_group.
         assert not isinstance(tensor, tuple)
@@ -466,6 +482,20 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         # the first compute completion
         if current_group == 0:
             self.d2h_stream.wait_stream(torch.cuda.current_stream())
 
+        if not self.double_buffer_created:
+            # Creating the first copy of double buffer for tensors that are offloaded
+            for tensor_tag, buf in self.tensor_tag_to_buf.items():
+                if isinstance(buf, list):
+                    for b in buf:
+                        self.reload_double_buffer[0].append(
+                            torch.empty_like(b) if self.double_buffering else None
+                        )
+                else:
+                    self.reload_double_buffer[0].append(
+                        torch.empty_like(buf) if self.double_buffering else None
+                    )
+
         self.bulk_offload_group(current_group)
 
         # Window map data structure helps us synchronize based on number
@@ -495,6 +525,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         # Increment the offload group count to keep track
         self.offloaded_group_count += 1
 
+        if not self.double_buffer_created:
+            # Creating second copy of double buffer for tensors that are offloaded
+            if current_group == (self.num_layers - 1):
+                for buf in self.reload_double_buffer[0]:
+                    self.reload_double_buffer[1].append(
+                        torch.empty_like(buf) if self.double_buffering else None
+                    )
+
+            self.double_buffer_created = True
 
     def on_group_commit_forward(self):
         """This function will cause host device synchronization"""
         # handle synchronization events
@@ -506,21 +545,32 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         """Bulk reload group."""
         assert group_to_reload < self.num_offload_group
 
+        buffer_idx = 0
+        double_buffer_idx = group_to_reload % 2
+
         with torch.cuda.stream(self.h2d_stream):
             # move back tensors
             for tensor_label, state in self.tensor_tag_to_state.items():
                 group_id, _ = tensor_label
                 if group_id == group_to_reload:
                     if isinstance(state, tuple):
-                        recovered_tensor = SynchronizedGroupOffloadHandler.reload(state)
+                        recovered_tensor = SynchronizedGroupOffloadHandler.reload(
+                            state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx]
+                        )
+                        buffer_idx = buffer_idx + 1
                         self.tensor_tag_to_state[tensor_label] = recovered_tensor
                     elif isinstance(state, list):
                         tensor_list = []
                         for state_tuple in state:
                             if isinstance(state_tuple, tuple):
                                 tensor_list.append(
-                                    SynchronizedGroupOffloadHandler.reload(state_tuple)
+                                    SynchronizedGroupOffloadHandler.reload(
+                                        state_tuple,
+                                        True,
+                                        self.reload_double_buffer[double_buffer_idx][buffer_idx],
+                                    )
                                 )
+                                buffer_idx = buffer_idx + 1
                             else:
                                 tensor_list.append(state_tuple)
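The parity index above is the whole scheduling trick: reloads for even-numbered groups land in buffer set 0 and odd-numbered groups in buffer set 1, so the H2D copies for one group can be issued on `h2d_stream` while compute still reads tensors reloaded into the other set. A standalone sketch of the ping-pong indexing (illustrative names, assumes a CUDA device; not TE's internals):

    import torch

    num_groups = 4
    host_copies = [torch.randn(256, 256).pin_memory() for _ in range(num_groups)]

    # Two persistent reload buffers, used alternately by group parity.
    double_buffer = [
        torch.empty(256, 256, device="cuda"),  # serves even-numbered groups
        torch.empty(256, 256, device="cuda"),  # serves odd-numbered groups
    ]

    for group in range(num_groups):
        dst = double_buffer[group % 2]
        dst.copy_(host_copies[group], non_blocking=True)
        # ... compute consuming `dst` must complete before group + 2 overwrites it ...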
@@ -574,6 +624,7 @@ def get_cpu_offload_context(
     model_layers: int = 1,
     offload_activations: bool = True,
     offload_weights: bool = False,
+    double_buffering: bool = False,
 ):
     """
     This function returns the CPU Offload context and the synchronizer function that needs to be
@@ -602,6 +653,8 @@ def get_cpu_offload_context(
         When set to `True`, offloads the activations for the TE layer.
     offload_weights: bool, default = `True`
         When set to `True`, offloads the weights for the TE layer.
+    double_buffering: bool, default = `False`
+        When set to `True`, uses double buffering for offloading.
     """
@@ -633,6 +686,7 @@ def get_cpu_offload_context(
         num_offload_group=num_layers,
         num_model_group=model_layers,
         tensor_need_offloading_checker=tensor_need_offloading_checker,
+        double_buffering=double_buffering,
     )
 
     def group_prefetch_offload_commit_async(tensor):
......
@@ -37,8 +37,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     Must be used carefully.
     """
     for t in tensors:
         if t is not None:
+            # Workaround for double buffering in cpu offload
+            if hasattr(t, "do_not_clear"):
+                continue
+
+            if hasattr(t, "get_data_tensors"):
+                if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()):
+                    continue
+
             if hasattr(t, "clear"):
                 t.clear()
             else:
......
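The guard works because a reloaded tensor's storage now belongs to the persistent double buffer, so freeing it through the usual activation-clearing path would corrupt a later reload. A standalone sketch of the marker mechanics (simplified helper, not TE's `clear_tensor_data`):

    import torch

    def clear_tensor_data_sketch(*tensors):
        """Drop tensor storage unless the tensor is marked as double-buffer owned."""
        for t in tensors:
            if t is None:
                continue
            if hasattr(t, "do_not_clear"):
                continue  # storage is owned by the reload double buffer; keep it
            t.data = torch.Tensor()  # release the reference to the underlying storage

    a = torch.randn(4)
    b = torch.randn(4)
    b.do_not_clear = True  # as tagged in tensor_pop() when double buffering is on
    clear_tensor_data_sketch(a, b)
    print(a.numel(), b.numel())  # -> 0 4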