Commit 858d7899 authored by Kexin Yu

Merge branch 'master' of https://github.com/NVIDIA/apex

parents 8d2647f8 2ca894da
@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
         reg_inside_moment (bool, optional): whether do regularization (norm and L2)
             in momentum calculation. True for include, False for not include and
             only do it on update term. (default: False)
-        grad_averaging (bool, optional): whether apply (1-beta2) to grad when
+        grad_averaging (bool, optional): whether apply (1-beta1) to grad when
             calculating running averages of gradient. (default: True)
         norm_type (int, optional): which norm to calculate for each layer.
             2 for L2 norm, and 0 for infinite norm. These 2 are only supported
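
Note on the docstring fix: in the NovoGrad formulation, beta2 drives the per-layer running average of the squared gradient norm, while beta1 drives the first-moment average of the normalized gradient, so grad averaging scales the gradient term by (1-beta1), not (1-beta2). A minimal unfused sketch of that update for one parameter (novograd_step and its argument names are illustrative, not apex's fused kernel API):

    import torch

    def novograd_step(p, grad, exp_avg, v_norm, lr, beta1, beta2, eps, grad_averaging=True):
        # Per-layer second moment: running average of the squared gradient norm (beta2).
        v_norm = beta2 * v_norm + (1 - beta2) * grad.norm(2) ** 2
        update = grad / (v_norm.sqrt() + eps)
        if grad_averaging:
            update = update * (1 - beta1)  # the (1-beta1) factor the docstring now names
        # First moment: running average of the normalized gradient (beta1).
        exp_avg.mul_(beta1).add_(update)
        p.data.add_(exp_avg, alpha=-lr)
        return v_norm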
@@ -25,21 +25,19 @@ try:
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

-def fast_collate(batch):
+def fast_collate(batch, memory_format):
     imgs = [img[0] for img in batch]
     targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
     w = imgs[0].size[0]
     h = imgs[0].size[1]
-    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
+    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
     for i, img in enumerate(imgs):
         nump_array = np.asarray(img, dtype=np.uint8)
         if(nump_array.ndim < 3):
             nump_array = np.expand_dims(nump_array, axis=-1)
         nump_array = np.rollaxis(nump_array, 2)
         tensor[i] += torch.from_numpy(nump_array)
     return tensor, targets
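
For context on the new memory_format argument: torch.channels_last keeps the logical NCHW shape but stores the data NHWC in memory, which is the layout cuDNN's tensor-core kernels prefer. A quick illustrative check (the sizes are arbitrary):

    import torch

    x = torch.zeros((8, 3, 224, 224), dtype=torch.uint8).contiguous(memory_format=torch.channels_last)
    print(x.shape)     # torch.Size([8, 3, 224, 224]) -- the logical shape is unchanged
    print(x.stride())  # (150528, 1, 672, 3) -- channels are now the innermost dimension
    print(x.is_contiguous(memory_format=torch.channels_last))  # True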
@@ -90,6 +88,7 @@ def parse():
     parser.add_argument('--opt-level', type=str)
     parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
     parser.add_argument('--loss-scale', type=str, default=None)
+    parser.add_argument('--channels-last', type=bool, default=False)
     args = parser.parse_args()
     return args
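
One caveat with the new flag: argparse's type=bool applies bool() to the raw string, and every non-empty string (including "False") is truthy, so --channels-last False still enables the option. The conventional spelling, had it been used here, is a store_true action (sketch only, not part of this commit):

    parser.add_argument('--channels-last', action='store_true',
                        help='use channels_last (NHWC) memory format')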
@@ -127,6 +126,11 @@ def main():
     assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

+    if args.channels_last:
+        memory_format = torch.channels_last
+    else:
+        memory_format = torch.contiguous_format
+
     # create model
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
@@ -140,10 +144,10 @@ def main():
         print("using apex synced BN")
         model = apex.parallel.convert_syncbn_model(model)

-    model = model.cuda()
+    model = model.cuda().to(memory_format=memory_format)

     # Scale learning rate based on global batch size
-    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
+    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
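
Converting the model with .to(memory_format=...) rewrites the strides of the conv weights to match the activations produced by fast_collate, so both sides of every convolution agree on layout. A minimal sketch of the pairing (resnet50 stands in for the model selected via args.arch):

    import torch
    import torchvision

    model = torchvision.models.resnet50().cuda().to(memory_format=torch.channels_last)
    input = torch.randn(8, 3, 224, 224, device='cuda').contiguous(memory_format=torch.channels_last)
    output = model(input)  # cuDNN can now pick NHWC kernels throughout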
@@ -161,7 +165,7 @@ def main():
     # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
     # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
     if args.distributed:
-        # By default, apex.parallel.DistributedDataParallel overlaps communication with
+        # By default, apex.parallel.DistributedDataParallel overlaps communication with
         # computation in the backward pass.
         # model = DDP(model)
         # delay_allreduce delays all communication to the end of the backward pass.
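
The two modes these comments describe, spelled out (a sketch; delay_allreduce is the apex.parallel.DistributedDataParallel argument the comment refers to):

    from apex.parallel import DistributedDataParallel as DDP

    # Default: allreduces overlap with backward-pass computation.
    # model = DDP(model)

    # Alternative named in the comment: buffer all gradients and launch the
    # allreduces once, at the end of the backward pass.
    model = DDP(model, delay_allreduce=True)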
@@ -218,16 +222,18 @@ def main():
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

+    collate_fn = lambda b: fast_collate(b, memory_format)
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)

     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True,
         sampler=val_sampler,
-        collate_fn=fast_collate)
+        collate_fn=collate_fn)

     if args.evaluate:
         validate(val_loader, model, criterion)
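
A note on the lambda: DataLoader workers pickle collate_fn when the multiprocessing start method is spawn (Windows, macOS), and lambdas are not picklable, so this pattern relies on fork as used on Linux. functools.partial is the portable equivalent (sketch):

    import functools

    collate_fn = functools.partial(fast_collate, memory_format=memory_format)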
@@ -297,7 +303,7 @@ class data_prefetcher():
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
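
For orientation, the prefetcher is consumed in a loop of the following shape: each next() returns the batch already staged on the side stream and immediately begins copying the following one, so the host-to-device transfer overlaps compute (a sketch; model and criterion are stand-ins):

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    while input is not None:
        output = model(input)            # compute overlaps the next batch's copy
        loss = criterion(output, target)
        # backward/step elided
        input, target = prefetcher.next()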
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
             # Measure accuracy
             prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

-            # Average loss and accuracy across processes for logging
+            # Average loss and accuracy across processes for logging
             if args.distributed:
                 reduced_loss = reduce_tensor(loss.data)
                 prec1 = reduce_tensor(prec1)
                 prec5 = reduce_tensor(prec5)
             else:
                 reduced_loss = loss.data

             # to_python_float incurs a host<->device sync
             losses.update(to_python_float(reduced_loss), input.size(0))
             top1.update(to_python_float(prec1), input.size(0))
             top5.update(to_python_float(prec5), input.size(0))

             torch.cuda.synchronize()
             batch_time.update((time.time() - end)/args.print_freq)
             end = time.time()
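
reduce_tensor is defined elsewhere in this example; the conventional implementation averages a value across ranks with a single all_reduce (a sketch that takes world_size explicitly, where the original reads it from args):

    import torch.distributed as dist

    def reduce_tensor(tensor, world_size):
        rt = tensor.clone()                        # avoid mutating the caller's tensor
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the value from every rank
        rt /= world_size                           # sum -> mean
        return rt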
@@ -2,7 +2,6 @@ import torch
 from setuptools import setup, find_packages
 import subprocess
-from pip._internal import main as pipmain
 import sys
 import warnings
 import os
@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
 cmdclass = {}
 ext_modules = []
+extras = {}

 if "--pyprof" in sys.argv:
     with open('requirements.txt') as f:
         required_packages = f.read().splitlines()
-    pipmain(["install"] + required_packages)
+    extras['pyprof'] = required_packages
     try:
         sys.argv.remove("--pyprof")
     except:
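
Replacing the install-time pipmain() call with an extras_require entry hands dependency resolution back to pip: instead of setup.py installing requirements.txt as a side effect, users opt in with pip install -v --no-cache-dir ".[pyprof]", and the packages registered under the 'pyprof' key are resolved like any other dependency.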
......@@ -153,9 +153,7 @@ if "--bnp" in sys.argv:
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode',
'arch=compute_70,code=sm_70'] + version_dependent_macros}))
'-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
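
Dropping the hard-coded -gencode arch=compute_70,code=sm_70 pair means the --bnp extension is no longer pinned to Volta: torch.utils.cpp_extension then compiles for the architectures it detects on the build machine, or for those listed in the TORCH_CUDA_ARCH_LIST environment variable (e.g. TORCH_CUDA_ARCH_LIST="7.0;7.5").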
@@ -209,4 +207,5 @@ setup(
     description='PyTorch Extensions written by NVIDIA',
     ext_modules=ext_modules,
     cmdclass=cmdclass,
+    extras_require=extras,
 )