Commit d5e2bb4b authored by Michael Carilli's avatar Michael Carilli
Browse files

Fix rare caching allocator race condition in imagenet prefetcher

parent c3bcf18e
...@@ -272,9 +272,23 @@ class data_prefetcher(): ...@@ -272,9 +272,23 @@ class data_prefetcher():
self.next_input = None self.next_input = None
self.next_target = None self.next_target = None
return return
# if record_stream() doesn't work, another option is to make sure device inputs are created
# on the main stream.
# self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
# self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
# Need to make sure the memory allocated for next_* is not still in use by the main stream
# at the time we start copying to next_*:
# self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True) self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True) self.next_target = self.next_target.cuda(non_blocking=True)
# more code for the alternative if record_stream() doesn't work:
# copy_ will record the use of the pinned source tensor in this side stream.
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
# self.next_input = self.next_input_gpu
# self.next_target = self.next_target_gpu
# With Amp, it isn't necessary to manually convert data to half. # With Amp, it isn't necessary to manually convert data to half.
# if args.fp16: # if args.fp16:
# self.next_input = self.next_input.half() # self.next_input = self.next_input.half()
...@@ -286,6 +300,8 @@ class data_prefetcher(): ...@@ -286,6 +300,8 @@ class data_prefetcher():
torch.cuda.current_stream().wait_stream(self.stream) torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input input = self.next_input
target = self.next_target target = self.next_target
input.record_stream(torch.cuda.current_stream())
target.record_stream(torch.cuda.current_stream())
self.preload() self.preload()
return input, target return input, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment