Unverified Commit c8c05f38 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Load balanced offloading algorithm (#1057)



* Load balanced offloading algorithm
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 54c1cfad
...@@ -274,7 +274,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -274,7 +274,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def __init__( def __init__(
self, self,
num_offload_group, # must be <= actual number of groups (number of commits) num_offload_group, # must be <= actual number of groups (number of commits)
num_prefetch_group=1, num_model_group,
tensor_need_offloading_checker=(lambda t: True), tensor_need_offloading_checker=(lambda t: True),
debug=False, debug=False,
) -> None: ) -> None:
...@@ -283,19 +283,29 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -283,19 +283,29 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor_need_offloading_checker=tensor_need_offloading_checker, tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug, debug=debug,
) )
self.num_prefetch_group = num_prefetch_group # Number of layers in the model
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors # Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {} self.tensor_tag_to_buf = {}
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
self.layer_window_map = {}
# Logic to make offloading load balance across computation
# for optimal CPU/GPU interconnect usage
constant = 0
for i in range(self.num_offload_group):
self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
if i < (self.num_layers % self.num_offload_group):
self.layer_window_map[i] += i + 1
constant = i + 1
else:
self.layer_window_map[i] += constant
# allocate streams and events for synchronization # allocate streams and events for synchronization
self.d2h_stream = torch.cuda.Stream() self.d2h_stream = torch.cuda.Stream()
self.h2d_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream()
self.h2d_finish_events = []
self.compute_stream_bwd_start_events = []
for _ in range(self.num_offload_group):
self.h2d_finish_events.append(torch.cuda.Event())
self.compute_stream_bwd_start_events.append(torch.cuda.Event())
self.d2h_final_event = torch.cuda.Event()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
...@@ -352,41 +362,44 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -352,41 +362,44 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def synchronize_on_group_commit_forward(self, current_group): def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward.""" """Synchronize on group commit forward."""
# the host should wait for the copying of previous group
# to avoid overwriting buffer # For the first group, kickstart the offload after we have
previous_group = current_group - 1 # the first compute completion
if previous_group < self.num_offload_group: if current_group == 0:
torch.cuda.synchronize() self.d2h_stream.wait_stream(torch.cuda.current_stream())
self.bulk_offload_group(current_group)
# Have to release the memory held by activations of the previous layer
if previous_group >= 0: # Window map data structure helps us synchronize based on number
# of layers offloaded
if self.layer_window_map[self.offloaded_group_count] == current_group:
# Stream synchronization both ways
self.d2h_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, _ in self.tensor_tag_to_buf.items(): for tensor_tag, _ in self.tensor_tag_to_buf.items():
if tensor_tag[0] == previous_group: if tensor_tag[0] == self.offloaded_group_count:
self.tensor_tag_to_buf[tensor_tag] = None self.tensor_tag_to_buf[tensor_tag] = None
# the copying of this group should wait for the computation stream event # Time to offload the next group
if current_group < self.num_offload_group: if self.offloaded_group_count < (self.num_offload_group - 1):
# perform bulk offloading self.bulk_offload_group(self.offloaded_group_count + 1)
self.bulk_offload_group(current_group)
if current_group == self.num_offload_group - 1: # Increment the offload group count to keep track
self.d2h_stream.record_event(self.d2h_final_event) self.offloaded_group_count += 1
def on_group_commit_forward(self): def on_group_commit_forward(self):
"""This function will cause host device synchronization""" """This function will cause host device synchronization"""
# handle synchronization events # handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group) self.synchronize_on_group_commit_forward(self.current_group)
# during forward, the next_group_to_fetch always points to the min of
# the last committed group, and the last offloaded group
self.next_group_to_fetch = min(self.current_group, self.num_offload_group - 1)
super().on_group_commit_forward() super().on_group_commit_forward()
def bulk_reload_group(self, group_to_reload): def bulk_reload_group(self, group_to_reload):
"""Bulk reload group.""" """Bulk reload group."""
assert group_to_reload < self.num_offload_group assert group_to_reload < self.num_offload_group
if group_to_reload == self.num_offload_group - 1:
self.h2d_stream.wait_event(self.d2h_final_event)
with torch.cuda.stream(self.h2d_stream): with torch.cuda.stream(self.h2d_stream):
# move back tensors # move back tensors
for tensor_label, state in self.tensor_tag_to_state.items(): for tensor_label, state in self.tensor_tag_to_state.items():
...@@ -403,39 +416,29 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -403,39 +416,29 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.current_group -= 1 self.current_group -= 1
assert self.current_group >= 0 assert self.current_group >= 0
# decide the range of group to prefetch # Layer window data structure helps us to reload at right times
should_prefetch_until_group = self.current_group - self.num_prefetch_group if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
should_prefetch_until_group = max(should_prefetch_until_group, 0)
# do prefetch
for group_num_to_prefetch in range(
self.next_group_to_fetch, should_prefetch_until_group - 1, -1
):
# record the event in the compute stream, for h2d to wait
torch.cuda.current_stream().record_event(
self.compute_stream_bwd_start_events[group_num_to_prefetch]
)
# start of h2d should wait for the compute and the d2h
self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch])
# recover tensors (copy back from host) # Stream synchronization both ways
self.bulk_reload_group(group_num_to_prefetch) self.h2d_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.h2d_stream)
# record an event for the backward of this layer to wait # Time to reload the next group
self.h2d_stream.record_event(self.h2d_finish_events[group_num_to_prefetch]) self.bulk_reload_group(self.offloaded_group_count - 1)
# always is set to -1 at the end of the backward # Decrease the offloading group counter
self.next_group_to_fetch = min(self.num_offload_group - 1, should_prefetch_until_group - 1) self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0
# wait for the current group # Last group computation needs to wait till all the reloads complete
if self.current_group < self.num_offload_group: if self.current_group == 0:
torch.cuda.current_stream().wait_event(self.h2d_finish_events[self.current_group]) torch.cuda.current_stream().wait_stream(self.h2d_stream)
self.offloaded_group_count = 0
def get_cpu_offload_context( def get_cpu_offload_context(
enabled: bool = False, enabled: bool = False,
num_layers: int = 1, num_layers: int = 1,
model_layers: int = 1,
offload_activations: bool = True, offload_activations: bool = True,
offload_weights: bool = True, offload_weights: bool = True,
): ):
...@@ -460,6 +463,8 @@ def get_cpu_offload_context( ...@@ -460,6 +463,8 @@ def get_cpu_offload_context(
num_layers: int, default = 1 num_layers: int, default = 1
Determines the number of transformer layers Determines the number of transformer layers
you want to offload activations/weights for. you want to offload activations/weights for.
model_layers: int, default = 1
Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True` offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer. When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `True` offload_weights: bool, default = `True`
...@@ -491,7 +496,7 @@ def get_cpu_offload_context( ...@@ -491,7 +496,7 @@ def get_cpu_offload_context(
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers, num_offload_group=num_layers,
num_prefetch_group=1, num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker, tensor_need_offloading_checker=tensor_need_offloading_checker,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment