Commit cae6005c authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Update imagenet example to fast version.

parent 343590a1
...@@ -16,14 +16,14 @@ import torchvision.transforms as transforms ...@@ -16,14 +16,14 @@ import torchvision.transforms as transforms
import torchvision.datasets as datasets import torchvision.datasets as datasets
import torchvision.models as models import torchvision.models as models
import numpy as np
try: try:
from apex.parallel import DistributedDataParallel as DDP from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import * from apex.fp16_utils import *
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import numpy as np
model_names = sorted(name for name in models.__dict__ model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
and callable(models.__dict__[name])) and callable(models.__dict__[name]))
...@@ -61,8 +61,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true', ...@@ -61,8 +61,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
parser.add_argument('--fp16', action='store_true', parser.add_argument('--fp16', action='store_true',
help='Run model fp16 mode.') help='Run model fp16 mode.')
parser.add_argument('--static-loss-scale', type=float, default=1, parser.add_argument('--loss-scale', type=float, default=1,
help='Static loss scale, positive power of 2 values can improve fp16 convergence.') help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
parser.add_argument('--prof', dest='prof', action='store_true', parser.add_argument('--prof', dest='prof', action='store_true',
help='Only run 10 iterations for profiling.') help='Only run 10 iterations for profiling.')
...@@ -80,6 +80,26 @@ parser.add_argument('--rank', default=0, type=int, ...@@ -80,6 +80,26 @@ parser.add_argument('--rank', default=0, type=int,
cudnn.benchmark = True cudnn.benchmark = True
import numpy as np
def fast_collate(batch):
    """Collate a batch of (image, target) pairs into raw uint8 tensors.

    Unlike the default collate_fn, this skips ToTensor()/normalize on the
    CPU: pixels stay uint8 and are packed into a single NCHW tensor, so the
    float conversion + normalization can be done later (e.g. on the GPU).

    Args:
        batch: iterable of ``(img, target)`` pairs, where ``img`` is anything
            ``np.asarray`` accepts (PIL Image, ndarray) in HWC or HW layout,
            and ``target`` is an integer class label. All images are assumed
            to share the first image's spatial size — TODO confirm upstream
            transforms guarantee this.

    Returns:
        Tuple of (uint8 tensor of shape (N, 3, H, W),
                  int64 tensor of shape (N,)).
    """
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    # Derive H/W from the pixel data itself rather than PIL's .size attribute
    # (which is (W, H)); behavior for PIL images is unchanged, and plain
    # ndarray inputs now work as well.
    first = np.asarray(imgs[0])
    h, w = first.shape[0], first.shape[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale: add a channel axis; the broadcast assignment below
            # replicates the single channel across all 3 output channels.
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW to match the (N, 3, H, W) output layout.
        nump_array = np.rollaxis(nump_array, 2)
        # Broadcast copy instead of the original `+=` into a zero buffer:
        # identical result, no add. Also drops the original's unused
        # `tens = torch.from_numpy(...)` per-image allocation.
        tensor[i] = torch.from_numpy(nump_array)
    return tensor, targets
best_prec1 = 0 best_prec1 = 0
args = parser.parse_args() args = parser.parse_args()
def main(): def main():
...@@ -93,18 +113,12 @@ def main(): ...@@ -93,18 +113,12 @@ def main():
if args.distributed: if args.distributed:
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
dist.init_process_group(backend=args.dist_backend, dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
world_size=args.world_size,
rank=args.rank)
if args.fp16: if args.fp16:
assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
if args.static_loss_scale != 1.0:
if not args.fp16:
print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")
# create model # create model
if args.pretrained: if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch)) print("=> using pre-trained model '{}'".format(args.arch))
...@@ -154,7 +168,7 @@ def main(): ...@@ -154,7 +168,7 @@ def main():
if(args.arch == "inception_v3"): if(args.arch == "inception_v3"):
crop_size = 299 crop_size = 299
val_size = 320 # Arbitrarily chosen, adjustable. val_size = 320 # I chose this value arbitrarily, we can adjust.
else: else:
crop_size = 224 crop_size = 224
val_size = 256 val_size = 256
...@@ -164,8 +178,8 @@ def main(): ...@@ -164,8 +178,8 @@ def main():
transforms.Compose([ transforms.Compose([
transforms.RandomResizedCrop(crop_size), transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #transforms.ToTensor(), Too slow
normalize, #normalize,
])) ]))
if args.distributed: if args.distributed:
...@@ -175,7 +189,7 @@ def main(): ...@@ -175,7 +189,7 @@ def main():
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler) num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([ datasets.ImageFolder(valdir, transforms.Compose([
...@@ -215,6 +229,13 @@ def main(): ...@@ -215,6 +229,13 @@ def main():
'optimizer' : optimizer.state_dict(), 'optimizer' : optimizer.state_dict(),
}, is_best) }, is_best)
def to_python_float(t):
    """Return the scalar value of *t* as a plain Python number.

    ``item()`` is a recent addition to torch tensors, so this helper keeps
    backward compatibility: it prefers ``t.item()`` when the method exists
    and otherwise falls back to indexing the first element (the old 1-element
    tensor convention).
    """
    return t.item() if hasattr(t, 'item') else t[0]
class data_prefetcher(): class data_prefetcher():
def __init__(self, loader): def __init__(self, loader):
self.loader = iter(loader) self.loader = iter(loader)
...@@ -284,15 +305,15 @@ def train(train_loader, model, criterion, optimizer, epoch): ...@@ -284,15 +305,15 @@ def train(train_loader, model, criterion, optimizer, epoch):
top1.update(to_python_float(prec1), input.size(0)) top1.update(to_python_float(prec1), input.size(0))
top5.update(to_python_float(prec5), input.size(0)) top5.update(to_python_float(prec5), input.size(0))
loss = loss*args.loss_scale
# compute gradient and do SGD step # compute gradient and do SGD step
if args.fp16: if args.fp16:
loss = loss*args.static_loss_scale
model.zero_grad() model.zero_grad()
loss.backward() loss.backward()
model_grads_to_master_grads(model_params, master_params) model_grads_to_master_grads(model_params, master_params)
if args.static_loss_scale != 1: if args.loss_scale != 1:
for param in master_params: for param in master_params:
param.grad.data = param.grad.data/args.static_loss_scale param.grad.data = param.grad.data/args.loss_scale
optimizer.step() optimizer.step()
master_params_to_model_params(model_params, master_params) master_params_to_model_params(model_params, master_params)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment