Commit cc85a2e5 authored by Michael Carilli

async->non_blocking, module-specific logging

parent 859f528b
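This commit does two things. First, it replaces the older tensor.cuda(async=True) spelling: `async` became a reserved keyword in Python 3.7, so the old call no longer even parses. A minimal sketch of the replacement (the tensor here is illustrative, not from the diff):

    import torch

    x = torch.randn(8).pin_memory()  # page-locked host memory
    # x.cuda(async=True) is a SyntaxError on Python 3.7+.
    # The keyword replacement is non_blocking:
    if torch.cuda.is_available():
        y = x.cuda(non_blocking=True)  # copy may overlap host-side work

Second, it routes amp's log messages through a named 'apex.amp' logger instead of the root logger, so applications can control amp's verbosity independently.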
@@ -56,7 +56,8 @@ class AmpHandle(object):
         if should_skip:
             optimizer_step = optimizer.step
             def skip_step():
-                logging.info('Gradient overflow, skipping update')
+                logger = logging.getLogger('apex.amp')
+                logger.info('Gradient overflow, skipping update')
                 optimizer.step = optimizer_step
             optimizer.step = skip_step
......@@ -76,7 +76,8 @@ class OptimWrapper(object):
'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.')
if any(self._skip_next):
logging.info('Gradient overflow, skipping update')
logger = logging.getLogger('apex.amp')
logger.info('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss
else:
return self._optimizer.step(closure=closure)
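Both hunks above replace a call on the root logger with a logger scoped to the module. A minimal user-side sketch of what this enables (standard-library logging only; the handler setup is illustrative):

    import logging

    # Hide apex's 'Gradient overflow, skipping update' info messages
    # without touching the rest of the application's logging:
    logging.getLogger('apex.amp').setLevel(logging.WARNING)

    # Or give the apex.amp logger its own handler and format:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('[%(name)s] %(message)s'))
    logging.getLogger('apex.amp').addHandler(handler)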
@@ -275,8 +275,8 @@ class data_prefetcher():
             self.next_target = None
             return
         with torch.cuda.stream(self.stream):
-            self.next_input = self.next_input.cuda(async=True)
-            self.next_target = self.next_target.cuda(async=True)
+            self.next_input = self.next_input.cuda(non_blocking=True)
+            self.next_target = self.next_target.cuda(non_blocking=True)
             if args.fp16:
                 self.next_input = self.next_input.half()
             else:
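This hunk, and the three data_prefetcher hunks below, are the same mechanical rename inside the examples' data_prefetcher. A minimal sketch of the pattern the prefetcher relies on, assuming a CUDA device; note that non_blocking=True only yields a genuinely asynchronous copy when the source tensor sits in pinned (page-locked) host memory:

    import torch

    stream = torch.cuda.Stream()                       # side stream for copies
    batch = torch.randn(64, 3, 224, 224).pin_memory()  # pinned source tensor
    with torch.cuda.stream(stream):
        batch_gpu = batch.cuda(non_blocking=True)      # enqueued on the side stream
    # The default stream must wait for the copy before consuming batch_gpu:
    torch.cuda.current_stream().wait_stream(stream)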
@@ -256,8 +256,8 @@ class data_prefetcher():
             self.next_target = None
             return
         with torch.cuda.stream(self.stream):
-            self.next_input = self.next_input.cuda(async=True)
-            self.next_target = self.next_target.cuda(async=True)
+            self.next_input = self.next_input.cuda(non_blocking=True)
+            self.next_target = self.next_target.cuda(non_blocking=True)
             # With Amp, it isn't necessary to manually convert data to half.
             # Type conversions are done internally on the fly within patched torch functions.
             # if args.fp16:
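The commented-out half() conversion in this hunk reflects how the handle-based Amp API of this era behaves: patched torch functions cast inputs on the fly, so the prefetcher can hand fp32 data straight to the model. A hedged sketch of that flow (API names as commonly used with apex at this time, not taken from the diff):

    import torch
    from apex import amp

    amp_handle = amp.init()                       # patches torch functions
    model = torch.nn.Linear(1024, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    x = torch.randn(64, 1024).cuda(non_blocking=True)  # stays fp32; no .half()
    loss = model(x).sum()                         # casts happen inside patched ops
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()                    # overflow detection lives here
    optimizer.step()                              # monkeypatched to skip on overflow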
@@ -265,8 +265,8 @@ class data_prefetcher():
             self.next_target = None
             return
         with torch.cuda.stream(self.stream):
-            self.next_input = self.next_input.cuda(async=True)
-            self.next_target = self.next_target.cuda(async=True)
+            self.next_input = self.next_input.cuda(non_blocking=True)
+            self.next_target = self.next_target.cuda(non_blocking=True)
             if args.fp16:
                 self.next_input = self.next_input.half()
             else:
@@ -259,8 +259,8 @@ class data_prefetcher():
             self.next_target = None
             return
         with torch.cuda.stream(self.stream):
-            self.next_input = self.next_input.cuda(async=True)
-            self.next_target = self.next_target.cuda(async=True)
+            self.next_input = self.next_input.cuda(non_blocking=True)
+            self.next_target = self.next_target.cuda(non_blocking=True)
             if args.fp16:
                 self.next_input = self.next_input.half()
             else:
@@ -377,8 +377,6 @@ def validate(val_loader, model, criterion):
     while input is not None:
         i += 1
 
-        target = target.cuda(async=True)
-
         # compute output
         with torch.no_grad():
             output = model(input)
@@ -6,6 +6,8 @@ import itertools as it
 import torch
 from apex.fp16_utils import FP16_Optimizer
 
+# Currently no-ops (tested via examples).
+# FP16_Optimizer to be deprecated and moved under unified Amp API.
 class TestFP16Optimizer(unittest.TestCase):
     def setUp(self):
         N, D_in, D_out = 64, 1024, 16
 import unittest
 import sys
 
-test_dirs = ["run_fp16_optimizer", "run_amp", "run_mixed_adam"]
+test_dirs = ["run_amp", "run_mixed_adam"]
 
 runner = unittest.TextTestRunner(verbosity=2)