Commit 7d0aef9f authored by Michael Carilli's avatar Michael Carilli
Browse files

Merge branch 'dist_valid'

parents 2af29c19 ae3be17a
......@@ -170,23 +170,26 @@ def main():
# transforms.ToTensor(), Too slow
# normalize,
]))
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
]))
train_sampler = None
val_sampler = None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
])),
val_dataset,
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True,
sampler=val_sampler,
collate_fn=fast_collate)
if args.evaluate:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment