Commit 90e5b05a authored by Michael Carilli's avatar Michael Carilli
Browse files

Fix end-of-epoch with record_stream

parent 1ccaaf4a
...@@ -300,8 +300,10 @@ class data_prefetcher(): ...@@ -300,8 +300,10 @@ class data_prefetcher():
torch.cuda.current_stream().wait_stream(self.stream) torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input input = self.next_input
target = self.next_target target = self.next_target
input.record_stream(torch.cuda.current_stream()) if input is not None:
target.record_stream(torch.cuda.current_stream()) input.record_stream(torch.cuda.current_stream())
if target is not None:
target.record_stream(torch.cuda.current_stream())
self.preload() self.preload()
return input, target return input, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment