Unverified Commit 38c26dd8, authored by Selvaraj Anandaraj, committed by GitHub
Browse files

Fixed double buffering issue for asymmetric layers (#1984)



* Fixed double buffering issue for asymmetric layers
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1470116e
...@@ -556,21 +556,33 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -556,21 +556,33 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
for tensor_label, state in self.tensor_tag_to_state.items(): for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label group_id, _ = tensor_label
if group_id == group_to_reload: if group_id == group_to_reload:
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
reload_buffer = None
if isinstance(state, tuple): if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload( recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx] state, True, reload_buffer
) )
buffer_idx = buffer_idx + 1 buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list): elif isinstance(state, list):
tensor_list = [] tensor_list = []
for state_tuple in state: for state_tuple in state:
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
reload_buffer = None
if isinstance(state_tuple, tuple): if isinstance(state_tuple, tuple):
tensor_list.append( tensor_list.append(
SynchronizedGroupOffloadHandler.reload( SynchronizedGroupOffloadHandler.reload(
state_tuple, state_tuple,
True, True,
self.reload_double_buffer[double_buffer_idx][buffer_idx], reload_buffer,
) )
) )
buffer_idx = buffer_idx + 1 buffer_idx = buffer_idx + 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment