OpenDAS / apex

Commit d5e2bb4b authored Jun 17, 2019 by Michael Carilli
Fix rare caching allocator race condition in imagenet prefetcher
parent c3bcf18e
Showing 1 changed file with 16 additions and 0 deletions.
examples/imagenet/main_amp.py  +16  −0

@@ -272,9 +272,23 @@ class data_prefetcher():
             self.next_input = None
             self.next_target = None
             return
+        # if record_stream() doesn't work, another option is to make sure device inputs are created
+        # on the main stream.
+        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
+        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
+        # Need to make sure the memory allocated for next_* is not still in use by the main stream
+        # at the time we start copying to next_*:
+        # self.stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.stream):
             self.next_input = self.next_input.cuda(non_blocking=True)
             self.next_target = self.next_target.cuda(non_blocking=True)
+            # more code for the alternative if record_stream() doesn't work:
+            # copy_ will record the use of the pinned source tensor in this side stream.
+            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+            # self.next_input = self.next_input_gpu
+            # self.next_target = self.next_target_gpu
+
             # With Amp, it isn't necessary to manually convert data to half.
             # if args.fp16:
             #     self.next_input = self.next_input.half()
...
@@ -286,6 +300,8 @@ class data_prefetcher():
         torch.cuda.current_stream().wait_stream(self.stream)
         input = self.next_input
         target = self.next_target
+        input.record_stream(torch.cuda.current_stream())
+        target.record_stream(torch.cuda.current_stream())
         self.preload()
         return input, target
...
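The two record_stream() calls are the whole fix. The device tensors produced by the non-blocking .cuda() copies are allocated on the prefetcher's side stream, so once they are handed to the training loop on the main stream, the caching allocator could otherwise recycle their memory for the next preload() before the main stream has finished reading them. Recording the current (main) stream on each tensor tells the allocator to hold that memory until the main stream's pending work completes. Below is a minimal, self-contained sketch of the pattern, assuming a standard DataLoader with pin_memory=True so the non-blocking copies can actually overlap; the class name PrefetcherSketch and the None-guards are illustrative additions, not part of the committed file.

import torch

class PrefetcherSketch:
    """Minimal sketch of the data_prefetcher pattern patched by this commit."""
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()  # side stream for async host-to-device copies
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # The copies run on the side stream, so the caching allocator
            # associates the resulting device memory with that stream.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        # The main stream must wait for the side-stream copies before using the tensors.
        torch.cuda.current_stream().wait_stream(self.stream)
        input, target = self.next_input, self.next_target
        if input is not None:
            # The fix: mark the tensors as in use on the main stream as well, so the
            # allocator does not hand their memory back to the side stream (for the
            # next preload) while the main stream is still reading them.
            input.record_stream(torch.cuda.current_stream())
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target

The commented-out alternative in the first hunk takes the opposite approach: allocate next_input_gpu / next_target_gpu on the main stream, make the side stream wait on the main stream, and copy_ into them, so the memory belongs to the main stream from the start and no record_stream() call is needed.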