You need to sign in or sign up before continuing.
Commit 39c9be85 authored by Michael Carilli's avatar Michael Carilli
Browse files

Syncing imagenet examples to use distributed validation

parent 7d0aef9f
...@@ -175,23 +175,26 @@ def main(): ...@@ -175,23 +175,26 @@ def main():
# transforms.ToTensor(), Too slow # transforms.ToTensor(), Too slow
# normalize, # normalize,
])) ]))
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
]))
train_sampler = None
val_sampler = None
if args.distributed: if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else: val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
train_sampler = None
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([ val_dataset,
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
])),
batch_size=args.batch_size, shuffle=False, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True, num_workers=args.workers, pin_memory=True,
sampler=val_sampler,
collate_fn=fast_collate) collate_fn=fast_collate)
if args.evaluate: if args.evaluate:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment