Unverified Commit 69cdc257 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Fix group offloading synchronization bug for parameter-only GroupModule's (#12077)



* update

* update

* refactor

* fuck yeah

* make style

* Update src/diffusers/hooks/group_offloading.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/hooks/group_offloading.py

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent cfd6ec74
...@@ -245,7 +245,6 @@ class ModuleGroup: ...@@ -245,7 +245,6 @@ class ModuleGroup:
param.data = self.cpu_param_dict[param] param.data = self.cpu_param_dict[param]
for buffer in self.buffers: for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer] buffer.data = self.cpu_param_dict[buffer]
else: else:
for group_module in self.modules: for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False) group_module.to(self.offload_device, non_blocking=False)
...@@ -303,9 +302,23 @@ class GroupOffloadingHook(ModelHook): ...@@ -303,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
if self.group.onload_leader == module: if self.group.onload_leader == module:
if self.group.onload_self: if self.group.onload_self:
self.group.onload_() self.group.onload_()
if self.next_group is not None and not self.next_group.onload_self:
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:
self.next_group.onload_() self.next_group.onload_()
should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
if should_synchronize:
# If this group didn't onload itself, it means it was asynchronously onloaded by the
# previous group. We need to synchronize the side stream to ensure parameters
# are completely loaded to proceed with forward pass. Without this, uninitialized
# weights will be used in the computation, leading to incorrect results
# Also, we should only do this synchronization if we don't already do it from the sync call in
# self.next_group.onload_, hence the `not should_onload_next_group` check.
self.group.stream.synchronize()
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs return args, kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment