Unverified commit 3be67060, authored by Aryan, committed by GitHub
Browse files

Fix Group offloading behaviour when using streams (#11097)

* update

* update
parent cb1b8b21
...@@ -181,6 +181,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook): ...@@ -181,6 +181,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
self._layer_execution_tracker_module_names = set() self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module): def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
    """Build a zero-argument callback that records one submodule execution.

    Each call of the returned callback logs a debug message and appends a
    ``(current_name, current_submodule)`` tuple to ``self.execution_order``.
    ``self`` and ``logger`` are not parameters; they are resolved from the
    surrounding scope.
    """

    def _record_execution():
        # The arguments are captured per factory call, so every callback
        # reports the name/submodule pair it was created for.
        executed_entry = (current_name, current_submodule)
        logger.debug(f"Adding {current_name} to the execution order")
        self.execution_order.append(executed_entry)

    return _record_execution
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass. # layers are executed during the forward pass.
...@@ -192,14 +199,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook): ...@@ -192,14 +199,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING) group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None: if group_offloading_hook is not None:
# For the first forward pass, we have to load in a blocking manner
def make_execution_order_update_callback(current_name, current_submodule): group_offloading_hook.group.non_blocking = False
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule)) layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER) registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
self._layer_execution_tracker_module_names.add(name) self._layer_execution_tracker_module_names.add(name)
...@@ -229,6 +230,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook): ...@@ -229,6 +230,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the layer execution tracker hooks from the submodules # Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook base_module_registry = module._diffusers_hook
registries = [submodule._diffusers_hook for _, submodule in self.execution_order] registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
for i in range(num_executed): for i in range(num_executed):
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False) registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
...@@ -236,8 +238,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook): ...@@ -236,8 +238,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False) base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
# Apply lazy prefetching by setting required attributes # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
# see the benefits of prefetching.
for hook in group_offloading_hooks:
hook.group.non_blocking = True
# Set required attributes for prefetching
if num_executed > 0: if num_executed > 0:
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment