Commit 2ca894da authored by Vitaly Fedyunin, committed by mcarilli
Browse files

Channels last support (#668)

parent b66ffc1d
...@@ -25,21 +25,19 @@ try: ...@@ -25,21 +25,19 @@ try:
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
def fast_collate(batch, memory_format=torch.contiguous_format):
    """Collate a batch of (image, target) pairs into uint8 tensors.

    Deliberately skips the usual ToTensor()/normalize work so that the
    float conversion and normalization can happen later (e.g. on the GPU).

    Args:
        batch: sequence of (image, target) pairs. Each image must expose a
            PIL-style ``.size`` == (width, height) and be convertible by
            ``np.asarray`` into an HxW or HxWxC uint8 array.
        memory_format: desired torch memory format of the image tensor,
            e.g. ``torch.channels_last``. Defaults to the standard
            contiguous layout, so callers that predate the channels-last
            support can still call with a single argument.

    Returns:
        (tensor, targets): ``tensor`` is an (N, 3, H, W) uint8 tensor in the
        requested memory format; ``targets`` is an int64 tensor of labels.
    """
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    # NOTE(review): assumes every image in the batch has the same size as
    # the first one -- confirm the upstream transforms guarantee this.
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(
        memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale image: add a trailing channel axis so the single
            # channel broadcasts across the 3 output channels below.
            nump_array = np.expand_dims(nump_array, axis=-1)
        # HWC -> CHW. np.moveaxis supersedes the deprecated np.rollaxis;
        # moveaxis(a, 2, 0) is exactly rollaxis(a, 2) for a 3-D array.
        nump_array = np.moveaxis(nump_array, 2, 0)
        # tensor[i] starts zeroed, so += is an in-place copy (and broadcasts
        # a 1-channel grayscale image across all 3 channels).
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
...@@ -90,6 +88,7 @@ def parse(): ...@@ -90,6 +88,7 @@ def parse():
parser.add_argument('--opt-level', type=str) parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--channels-last', type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -127,6 +126,11 @@ def main(): ...@@ -127,6 +126,11 @@ def main():
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
if args.channels_last:
memory_format = torch.channels_last
else:
memory_format = torch.contiguous_format
# create model # create model
if args.pretrained: if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch)) print("=> using pre-trained model '{}'".format(args.arch))
...@@ -140,7 +144,7 @@ def main(): ...@@ -140,7 +144,7 @@ def main():
print("using apex synced BN") print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model) model = apex.parallel.convert_syncbn_model(model)
model = model.cuda() model = model.cuda().to(memory_format=memory_format)
# Scale learning rate based on global batch size # Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256. args.lr = args.lr*float(args.batch_size*args.world_size)/256.
...@@ -218,16 +222,18 @@ def main(): ...@@ -218,16 +222,18 @@ def main():
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
collate_fn = lambda b: fast_collate(b, memory_format)
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
val_dataset, val_dataset,
batch_size=args.batch_size, shuffle=False, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True, num_workers=args.workers, pin_memory=True,
sampler=val_sampler, sampler=val_sampler,
collate_fn=fast_collate) collate_fn=collate_fn)
if args.evaluate: if args.evaluate:
validate(val_loader, model, criterion) validate(val_loader, model, criterion)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment