Commit 7d0aef9f authored by Michael Carilli's avatar Michael Carilli
Browse files

Merge branch 'dist_valid'

parents 2af29c19 ae3be17a
...@@ -170,23 +170,26 @@ def main(): ...@@ -170,23 +170,26 @@ def main():
# transforms.ToTensor(), Too slow # transforms.ToTensor(), Too slow
# normalize, # normalize,
])) ]))
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
]))
train_sampler = None
val_sampler = None
if args.distributed: if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else: val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
train_sampler = None
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([ val_dataset,
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
])),
batch_size=args.batch_size, shuffle=False, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True, num_workers=args.workers, pin_memory=True,
sampler=val_sampler,
collate_fn=fast_collate) collate_fn=fast_collate)
if args.evaluate: if args.evaluate:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment