Commit 34582381 authored by Michael Carilli's avatar Michael Carilli
Browse files

Syncing imagenet examples

parent d6b2e7d3
...@@ -80,9 +80,6 @@ parser.add_argument('--rank', default=0, type=int, ...@@ -80,9 +80,6 @@ parser.add_argument('--rank', default=0, type=int,
cudnn.benchmark = True cudnn.benchmark = True
import numpy as np
def fast_collate(batch): def fast_collate(batch):
imgs = [img[0] for img in batch] imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
...@@ -113,8 +110,10 @@ def main(): ...@@ -113,8 +110,10 @@ def main():
if args.distributed: if args.distributed:
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, dist.init_process_group(backend=args.dist_backend,
world_size=args.world_size, rank=args.rank) init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank)
if args.fp16: if args.fp16:
assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
...@@ -177,8 +176,8 @@ def main(): ...@@ -177,8 +176,8 @@ def main():
transforms.Compose([ transforms.Compose([
transforms.RandomResizedCrop(crop_size), transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
#transforms.ToTensor(), Too slow # transforms.ToTensor(), Too slow
#normalize, # normalize,
])) ]))
if args.distributed: if args.distributed:
...@@ -227,13 +226,6 @@ def main(): ...@@ -227,13 +226,6 @@ def main():
'optimizer' : optimizer.state_dict(), 'optimizer' : optimizer.state_dict(),
}, is_best) }, is_best)
def to_python_float(t):
    """Extract a plain Python number from *t*.

    Uses ``t.item()`` when available (tensor.item() is a recent addition,
    so this keeps backward compatibility); otherwise falls back to
    indexing the first element of a 1-element sequence/tensor.
    """
    return t.item() if hasattr(t, 'item') else t[0]
class data_prefetcher(): class data_prefetcher():
def __init__(self, loader): def __init__(self, loader):
self.loader = iter(loader) self.loader = iter(loader)
......
...@@ -16,14 +16,14 @@ import torchvision.transforms as transforms ...@@ -16,14 +16,14 @@ import torchvision.transforms as transforms
import torchvision.datasets as datasets import torchvision.datasets as datasets
import torchvision.models as models import torchvision.models as models
import numpy as np
try: try:
from apex.parallel import DistributedDataParallel as DDP from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import * from apex.fp16_utils import *
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import numpy as np
model_names = sorted(name for name in models.__dict__ model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
and callable(models.__dict__[name])) and callable(models.__dict__[name]))
...@@ -83,6 +83,23 @@ parser.add_argument('--rank', default=0, type=int, ...@@ -83,6 +83,23 @@ parser.add_argument('--rank', default=0, type=int,
cudnn.benchmark = True cudnn.benchmark = True
def fast_collate(batch):
    """Collate a batch of (PIL-image, label) pairs into uint8 tensors.

    Skips torchvision's ToTensor()/Normalize() (too slow on the loader
    workers); normalization is deferred to the GPU-side prefetcher.

    Args:
        batch: list of ``(img, target)`` pairs, where ``img`` exposes
            ``size`` as ``(width, height)`` and is convertible via
            ``np.asarray`` (e.g. a PIL Image).

    Returns:
        tuple ``(tensor, targets)``: a ``(N, 3, H, W)`` uint8 tensor and a
        ``(N,)`` int64 tensor of labels.
    """
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    # PIL convention: size is (width, height).
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        # Grayscale images come back as (H, W); give them a channel axis
        # so the HWC -> CHW roll below works uniformly.
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)  # HWC -> CHW
        # += broadcasts a 1-channel image across all 3 channel slots.
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
best_prec1 = 0 best_prec1 = 0
args = parser.parse_args() args = parser.parse_args()
def main(): def main():
...@@ -151,8 +168,6 @@ def main(): ...@@ -151,8 +168,6 @@ def main():
# Data loading code # Data loading code
traindir = os.path.join(args.data, 'train') traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val') valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if(args.arch == "inception_v3"): if(args.arch == "inception_v3"):
crop_size = 299 crop_size = 299
...@@ -166,8 +181,8 @@ def main(): ...@@ -166,8 +181,8 @@ def main():
transforms.Compose([ transforms.Compose([
transforms.RandomResizedCrop(crop_size), transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ToTensor(), # transforms.ToTensor(), Too slow
normalize, # normalize,
])) ]))
if args.distributed: if args.distributed:
...@@ -177,17 +192,16 @@ def main(): ...@@ -177,17 +192,16 @@ def main():
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler) num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([ datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size), transforms.Resize(val_size),
transforms.CenterCrop(crop_size), transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
])), ])),
batch_size=args.batch_size, shuffle=False, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True) num_workers=args.workers, pin_memory=True,
collate_fn=fast_collate)
if args.evaluate: if args.evaluate:
validate(val_loader, model, criterion) validate(val_loader, model, criterion)
...@@ -221,6 +235,11 @@ class data_prefetcher(): ...@@ -221,6 +235,11 @@ class data_prefetcher():
def __init__(self, loader): def __init__(self, loader):
self.loader = iter(loader) self.loader = iter(loader)
self.stream = torch.cuda.Stream() self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
if args.fp16:
self.mean = self.mean.half()
self.std = self.std.half()
self.preload() self.preload()
def preload(self): def preload(self):
...@@ -233,7 +252,12 @@ class data_prefetcher(): ...@@ -233,7 +252,12 @@ class data_prefetcher():
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True) self.next_input = self.next_input.cuda(async=True)
self.next_target = self.next_target.cuda(async=True) self.next_target = self.next_target.cuda(async=True)
if args.fp16:
self.next_input = self.next_input.half()
else:
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self): def next(self):
torch.cuda.current_stream().wait_stream(self.stream) torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input input = self.next_input
...@@ -304,11 +328,15 @@ def train(train_loader, model, criterion, optimizer, epoch): ...@@ -304,11 +328,15 @@ def train(train_loader, model, criterion, optimizer, epoch):
if args.rank == 0 and i % args.print_freq == 0 and i > 1: if args.rank == 0 and i % args.print_freq == 0 and i > 1:
print('Epoch: [{0}][{1}/{2}]\t' print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Speed {3:.3f} ({4:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time, epoch, i, len(train_loader),
args.world_size * args.batch_size / batch_time.val,
args.world_size * args.batch_size / batch_time.avg,
batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5)) data_time=data_time, loss=losses, top1=top1, top5=top5))
...@@ -359,10 +387,14 @@ def validate(val_loader, model, criterion): ...@@ -359,10 +387,14 @@ def validate(val_loader, model, criterion):
if args.rank == 0 and i % args.print_freq == 0: if args.rank == 0 and i % args.print_freq == 0:
print('Test: [{0}/{1}]\t' print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Speed {2:.3f} ({3:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses, i, len(val_loader),
args.world_size * args.batch_size / batch_time.val,
args.world_size * args.batch_size / batch_time.avg,
batch_time=batch_time, loss=losses,
top1=top1, top5=top5)) top1=top1, top5=top5))
input, target = prefetcher.next() input, target = prefetcher.next()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment