Commit 858d7899 authored by Kexin Yu's avatar Kexin Yu
Browse files

Merge branch 'master' of https://github.com/NVIDIA/apex

parents 8d2647f8 2ca894da
...@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer): ...@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
reg_inside_moment (bool, optional): whether do regularization (norm and L2) reg_inside_moment (bool, optional): whether do regularization (norm and L2)
in momentum calculation. True for include, False for not include and in momentum calculation. True for include, False for not include and
only do it on update term. (default: False) only do it on update term. (default: False)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when grad_averaging (bool, optional): whether apply (1-beta1) to grad when
calculating running averages of gradient. (default: True) calculating running averages of gradient. (default: True)
norm_type (int, optional): which norm to calculate for each layer. norm_type (int, optional): which norm to calculate for each layer.
2 for L2 norm, and 0 for infinite norm. These 2 are only supported 2 for L2 norm, and 0 for infinite norm. These 2 are only supported
......
...@@ -25,21 +25,19 @@ try: ...@@ -25,21 +25,19 @@ try:
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
def fast_collate(batch, memory_format):
    """Collate a list of (image, target) pairs into a uint8 batch tensor.

    Args:
        batch: sequence of (img, target) pairs; each img exposes a PIL-style
            ``.size`` of (width, height) and is convertible via ``np.asarray``.
        memory_format: torch memory format (e.g. ``torch.channels_last`` or
            ``torch.contiguous_format``) applied to the output batch tensor.

    Returns:
        tuple: ``(tensor, targets)`` where ``tensor`` is an (N, 3, H, W)
        uint8 tensor laid out in the requested memory format and ``targets``
        is an int64 tensor of the per-sample targets.
    """
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    # PIL-style .size is (width, height); assumes all images share one size
    # (true after the Resize/Crop transforms upstream) — TODO confirm.
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale image: add a trailing channel axis so the HWC->CHW
            # roll below works uniformly (broadcasts over the 3 channels).
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
...@@ -90,6 +88,7 @@ def parse(): ...@@ -90,6 +88,7 @@ def parse():
parser.add_argument('--opt-level', type=str) parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--channels-last', type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -127,6 +126,11 @@ def main(): ...@@ -127,6 +126,11 @@ def main():
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
if args.channels_last:
memory_format = torch.channels_last
else:
memory_format = torch.contiguous_format
# create model # create model
if args.pretrained: if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch)) print("=> using pre-trained model '{}'".format(args.arch))
...@@ -140,7 +144,7 @@ def main(): ...@@ -140,7 +144,7 @@ def main():
print("using apex synced BN") print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model) model = apex.parallel.convert_syncbn_model(model)
model = model.cuda() model = model.cuda().to(memory_format=memory_format)
# Scale learning rate based on global batch size # Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256. args.lr = args.lr*float(args.batch_size*args.world_size)/256.
...@@ -218,16 +222,18 @@ def main(): ...@@ -218,16 +222,18 @@ def main():
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
collate_fn = lambda b: fast_collate(b, memory_format)
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
val_dataset, val_dataset,
batch_size=args.batch_size, shuffle=False, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True, num_workers=args.workers, pin_memory=True,
sampler=val_sampler, sampler=val_sampler,
collate_fn=fast_collate) collate_fn=collate_fn)
if args.evaluate: if args.evaluate:
validate(val_loader, model, criterion) validate(val_loader, model, criterion)
......
...@@ -2,7 +2,6 @@ import torch ...@@ -2,7 +2,6 @@ import torch
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
from pip._internal import main as pipmain
import sys import sys
import warnings import warnings
import os import os
...@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4: ...@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
cmdclass = {} cmdclass = {}
ext_modules = [] ext_modules = []
extras = {}
if "--pyprof" in sys.argv: if "--pyprof" in sys.argv:
with open('requirements.txt') as f: with open('requirements.txt') as f:
required_packages = f.read().splitlines() required_packages = f.read().splitlines()
pipmain(["install"] + required_packages) extras['pyprof'] = required_packages
try: try:
sys.argv.remove("--pyprof") sys.argv.remove("--pyprof")
except: except:
...@@ -153,9 +153,7 @@ if "--bnp" in sys.argv: ...@@ -153,9 +153,7 @@ if "--bnp" in sys.argv:
'nvcc':['-DCUDA_HAS_FP16=1', 'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__', '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
'-gencode',
'arch=compute_70,code=sm_70'] + version_dependent_macros}))
if "--xentropy" in sys.argv: if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -209,4 +207,5 @@ setup( ...@@ -209,4 +207,5 @@ setup(
description='PyTorch Extensions written by NVIDIA', description='PyTorch Extensions written by NVIDIA',
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass=cmdclass, cmdclass=cmdclass,
extras_require=extras,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment