Commit 484292f0 authored by Michael Carilli

some test cleanup

parent 2445031d
import contextlib
import logging
import warnings
+import torch

from . import utils
from .opt import OptimWrapper
@@ -83,7 +84,6 @@ def scale_loss(loss,
                        "loss scale to {}".format(optimizer.loss_scaler.loss_scale()))
                optimizer.step = optimizer_step
            optimizer.step = skip_step

    # Probably ok to skip this if not delay_unscale
    if _amp_state.opt_properties.patch_torch_functions:
        _amp_state.handle._clear_cache()
...
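The hunk above shows the overflow-handling pattern in scale_loss: the real step method is stashed, and on overflow optimizer.step is swapped for a one-shot replacement. A minimal sketch of that pattern follows (hypothetical helper name; not the code from this commit):

import warnings

def install_skip_step(optimizer):
    optimizer_step = optimizer.step  # keep the real bound method

    def skip_step(*args, **kwargs):
        warnings.warn("Gradient overflow; skipping optimizer.step() once")
        optimizer.step = optimizer_step  # restore the real step afterwards

    optimizer.step = skip_step  # the next step() call is the skip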
@@ -81,9 +81,7 @@ class LossScaler(object):
        self._overflow_buf.zero_()

    def unscale(self, model_params, master_params, scale):
-        # torch.cuda.nvtx.range_push("unscale")
        if self._has_overflow:
-            # torch.cuda.nvtx.range_pop()
            return

        # Lots of defensive list processing going on here.  Way less efficient than
@@ -92,6 +90,12 @@ class LossScaler(object):
                               in zip(model_params, master_params)] # some of these may be None

        if LossScaler.has_fused_kernel:
+            # TODO:  Make these lists permanent attributes of self, so they don't need to be created
+            # or garbage collected.  Profiler shows that garbage collection overhead may be
+            # substantial (200-300 usec).
+            # This may be tricky because right now the lists need to be packed densely.
+            # Maybe this could be handled within the multi_tensor_apply wrapper
+            # (allow some Tensors to be None using at::optional).
            src_dst_pairs = {torch.float16 : {torch.float16 : [[],[]], torch.float32 : [[],[]]},
                             torch.float32 : {torch.float16 : [[],[]], torch.float32 : [[],[]]}}
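The TODO added above reasons about allocation and garbage-collection overhead. A hedged sketch of the caching it proposes (hypothetical class, not part of this commit) would keep the nested list pairs alive across calls and empty them in place:

import torch

class _PairCache(object):
    """Hypothetical cache for the src/dst list pairs built in unscale()."""
    def __init__(self):
        self.src_dst_pairs = {
            torch.float16: {torch.float16: [[], []], torch.float32: [[], []]},
            torch.float32: {torch.float16: [[], []], torch.float32: [[], []]}}

    def clear(self):
        # Empty the lists in place so the dicts and lists are reused rather
        # than recreated; the lists must still be repacked densely each call.
        for dst_map in self.src_dst_pairs.values():
            for pair in dst_map.values():
                del pair[0][:]
                del pair[1][:]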
@@ -142,6 +146,8 @@ class LossScaler(object):
        if scale == 1.0 and all_same and not self.dynamic:
            return

+        # TODO:  Make these lists permanent attributes of self, so they don't need to be created
+        # or garbage collected?
        model_grads = [mmp[0].grad.data for mmp in model_master_params if mmp[0].grad is not None]
        master_grads = [mmp[1].grad.data for mmp in model_master_params if mmp[1].grad is not None]
@@ -151,8 +157,6 @@ class LossScaler(object):
        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
            self._has_overflow = self._overflow_buf.item()
-        # torch.cuda.nvtx.range_pop()

    # Separate so unscale() can be called more than once before updating.
    def update_scale(self):
        if self._has_overflow and self.dynamic:
...
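The comment kept above ("Separate so unscale() can be called more than once before updating") is the key design point: with several parameter groups or gradient accumulation, unscale() may run repeatedly, and the scale is adjusted only once per step. A toy sketch of that call pattern (hypothetical and much simplified; not the LossScaler API):

import math

class ToyScaler(object):
    def __init__(self, scale=2.0**16):
        self._scale = scale
        self._has_overflow = False

    def unscale(self, grads):
        # Divide gradients by the current scale, recording any overflow.
        for i, g in enumerate(grads):
            grads[i] = g / self._scale
            if math.isinf(grads[i]) or math.isnan(grads[i]):
                self._has_overflow = True

    def update_scale(self):
        # Called once, after every unscale() for this step has run.
        if self._has_overflow:
            self._scale /= 2.0
            self._has_overflow = False

scaler = ToyScaler()
grads_a, grads_b = [1024.0], [float("inf")]
scaler.unscale(grads_a)  # first group
scaler.unscale(grads_b)  # second group, same step
scaler.update_scale()    # single scale update; the scale is halved here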
@@ -10,7 +10,7 @@ class MultiTensorApply(object):
            MultiTensorApply.available = True
            self.chunk_size = chunk_size
        except ImportError as err:
-            MultiTensorApply.availble = False
+            MultiTensorApply.available = False
            MultiTensorApply.import_err = err

    def check_avail(self):
...
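This one-character fix matters because check_avail() reads MultiTensorApply.available; the misspelled attribute meant a failed extension import never flipped the flag that later code checks. A sketch of the guarded-import pattern (simplified; amp_C is assumed here to be the compiled extension module):

class MultiTensorApplySketch(object):
    available = False
    import_err = None

    def __init__(self, chunk_size):
        try:
            import amp_C  # compiled extension; present only after a --cuda_ext build
            MultiTensorApplySketch.available = True
            self.chunk_size = chunk_size
        except ImportError as err:
            MultiTensorApplySketch.available = False
            MultiTensorApplySketch.import_err = err

    def check_avail(self):
        if not MultiTensorApplySketch.available:
            raise RuntimeError("Fused kernels unavailable: {}".format(
                MultiTensorApplySketch.import_err))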
@@ -107,6 +107,8 @@ print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

+print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
+
if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
...
@@ -46,6 +46,8 @@ rm False*
set -e

+print_banner "Installing Apex with --cuda_ext and --cpp_ext"
+
pushd ../../..
python setup.py install --cuda_ext --cpp_ext
popd
@@ -76,6 +78,8 @@ do
  set +x
done

+print_banner "Reinstalling apex without extensions"
+
pushd ../../..
python setup.py install
popd
@@ -102,6 +106,8 @@ do
  do
    for keep_batchnorm in "${keep_batchnorms[@]}"
    do
+      echo ""
+      echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR"
      set -x
      python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm}
      set +x
@@ -109,6 +115,8 @@ do
  done
done

+print_banner "Reinstalling Apex with --cuda_ext and --cpp_ext"
+
pushd ../../..
python setup.py install --cuda_ext --cpp_ext
popd