"cacheflow/vscode:/vscode.git/clone" did not exist on "9f88db35da012544a9cd9450c68b8df2e6509b92"
Commit 2ca894da authored by Vitaly Fedyunin's avatar Vitaly Fedyunin Committed by mcarilli
Browse files

Channels last support (#668)

parent b66ffc1d
......@@ -25,21 +25,19 @@ try:
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
def fast_collate(batch, memory_format):
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
......@@ -90,6 +88,7 @@ def parse():
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--channels-last', type=bool, default=False)
args = parser.parse_args()
return args
......@@ -127,6 +126,11 @@ def main():
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
if args.channels_last:
memory_format = torch.channels_last
else:
memory_format = torch.contiguous_format
# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
......@@ -140,7 +144,7 @@ def main():
print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model)
model = model.cuda()
model = model.cuda().to(memory_format=memory_format)
# Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
......@@ -218,16 +222,18 @@ def main():
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
collate_fn = lambda b: fast_collate(b, memory_format)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True,
sampler=val_sampler,
collate_fn=fast_collate)
collate_fn=collate_fn)
if args.evaluate:
validate(val_loader, model, criterion)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment