Unverified Commit 67fcc152 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files
parent e0e3d123
......@@ -551,17 +551,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
buffer_idx = 0
double_buffer_idx = group_to_reload % 2
main_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
reload_buffer = None
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, reload_buffer
)
......@@ -570,14 +576,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
reload_buffer = None
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
if isinstance(state_tuple, tuple):
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment