Commit 414dc119 authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Re-add normalization, correct typing.

parent 0081afb8
......@@ -163,8 +163,6 @@ def main():
# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if(args.arch == "inception_v3"):
crop_size = 299
......@@ -191,6 +189,10 @@ def main():
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size),
......@@ -240,6 +242,11 @@ class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
if args.fp16:
self.mean = self.mean.half()
self.std = self.std.half()
self.preload()
def preload(self):
......@@ -252,7 +259,12 @@ class data_prefetcher():
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True)
self.next_target = self.next_target.cuda(async=True)
if args.fp16:
self.next_input = self.next_input.half()
else:
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment