Commit cc85a2e5 authored by Michael Carilli
Browse files

async->non_blocking, module-specific logging

parent 859f528b
...@@ -56,7 +56,8 @@ class AmpHandle(object): ...@@ -56,7 +56,8 @@ class AmpHandle(object):
if should_skip: if should_skip:
optimizer_step = optimizer.step optimizer_step = optimizer.step
def skip_step(): def skip_step():
logging.info('Gradient overflow, skipping update') logger = logging.getLogger('apex.amp')
logger.info('Gradient overflow, skipping update')
optimizer.step = optimizer_step optimizer.step = optimizer_step
optimizer.step = skip_step optimizer.step = skip_step
......
...@@ -76,7 +76,8 @@ class OptimWrapper(object): ...@@ -76,7 +76,8 @@ class OptimWrapper(object):
'The `closure` argument is unsupported by the amp ' + 'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.') 'optimizer wrapper.')
if any(self._skip_next): if any(self._skip_next):
logging.info('Gradient overflow, skipping update') logger = logging.getLogger('apex.amp')
logger.info('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss self._skip_next = [False] * self._num_loss
else: else:
return self._optimizer.step(closure=closure) return self._optimizer.step(closure=closure)
......
...@@ -275,8 +275,8 @@ class data_prefetcher(): ...@@ -275,8 +275,8 @@ class data_prefetcher():
self.next_target = None self.next_target = None
return return
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True) self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(async=True) self.next_target = self.next_target.cuda(non_blocking=True)
if args.fp16: if args.fp16:
self.next_input = self.next_input.half() self.next_input = self.next_input.half()
else: else:
......
...@@ -256,8 +256,8 @@ class data_prefetcher(): ...@@ -256,8 +256,8 @@ class data_prefetcher():
self.next_target = None self.next_target = None
return return
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True) self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(async=True) self.next_target = self.next_target.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half. # With Amp, it isn't necessary to manually convert data to half.
# Type conversions are done internally on the fly within patched torch functions. # Type conversions are done internally on the fly within patched torch functions.
# if args.fp16: # if args.fp16:
......
...@@ -265,8 +265,8 @@ class data_prefetcher(): ...@@ -265,8 +265,8 @@ class data_prefetcher():
self.next_target = None self.next_target = None
return return
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True) self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(async=True) self.next_target = self.next_target.cuda(non_blocking=True)
if args.fp16: if args.fp16:
self.next_input = self.next_input.half() self.next_input = self.next_input.half()
else: else:
......
...@@ -259,8 +259,8 @@ class data_prefetcher(): ...@@ -259,8 +259,8 @@ class data_prefetcher():
self.next_target = None self.next_target = None
return return
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True) self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(async=True) self.next_target = self.next_target.cuda(non_blocking=True)
if args.fp16: if args.fp16:
self.next_input = self.next_input.half() self.next_input = self.next_input.half()
else: else:
...@@ -377,8 +377,6 @@ def validate(val_loader, model, criterion): ...@@ -377,8 +377,6 @@ def validate(val_loader, model, criterion):
while input is not None: while input is not None:
i += 1 i += 1
target = target.cuda(async=True)
# compute output # compute output
with torch.no_grad(): with torch.no_grad():
output = model(input) output = model(input)
......
...@@ -6,6 +6,8 @@ import itertools as it ...@@ -6,6 +6,8 @@ import itertools as it
import torch import torch
from apex.fp16_utils import FP16_Optimizer from apex.fp16_utils import FP16_Optimizer
# Currently no-ops (tested via examples).
# FP16_Optimizer to be deprecated and moved under unified Amp API.
class TestFP16Optimizer(unittest.TestCase): class TestFP16Optimizer(unittest.TestCase):
def setUp(self): def setUp(self):
N, D_in, D_out = 64, 1024, 16 N, D_in, D_out = 64, 1024, 16
......
import unittest import unittest
import sys import sys
test_dirs = ["run_fp16_optimizer", "run_amp", "run_mixed_adam"] test_dirs = ["run_amp", "run_mixed_adam"]
runner = unittest.TextTestRunner(verbosity=2) runner = unittest.TextTestRunner(verbosity=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment