Unverified commit ae81313f, authored by Francisco Massa, committed by GitHub

Miscellaneous improvements to the classification reference scripts (#894)

* Miscellaneous improvements to the classification reference scripts

* Fix lint
parent 43ab2fef
references/classification/train.py

```diff
@@ -59,7 +59,18 @@ def evaluate(model, criterion, data_loader, device):
     return metric_logger.acc1.global_avg
 
 
+def _get_cache_path(filepath):
+    import hashlib
+    h = hashlib.sha1(filepath.encode()).hexdigest()
+    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
+    cache_path = os.path.expanduser(cache_path)
+    return cache_path
+
+
 def main(args):
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
     utils.init_distributed_mode(args)
     print(args)
```
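The new `_get_cache_path` helper keys the cache on a SHA-1 hash of the dataset directory string, so different dataset paths never collide on the same cache file. A minimal sketch of the behaviour (the helper name and printed path below are illustrative):

```python
import hashlib
import os

def cache_path_for(filepath):
    # Same scheme as _get_cache_path above: hash the directory string
    # and keep the first 10 hex chars as the file name.
    h = hashlib.sha1(filepath.encode()).hexdigest()
    return os.path.expanduser(
        os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt"))

print(cache_path_for("/data/imagenet/train"))
# e.g. /home/you/.torch/vision/datasets/imagefolder/0f7dd93eb5.pt (hash made up)
```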
```diff
@@ -76,28 +87,45 @@ def main(args):
 
     print("Loading training data")
     st = time.time()
-    scale = (0.08, 1.0)
-    if args.model == 'mobilenet_v2':
-        scale = (0.2, 1.0)
-    dataset = torchvision.datasets.ImageFolder(
-        traindir,
-        transforms.Compose([
-            transforms.RandomResizedCrop(224, scale=scale),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            normalize,
-        ]))
+    cache_path = _get_cache_path(traindir)
+    if args.cache_dataset and os.path.exists(cache_path):
+        # Attention, as the transforms are also cached!
+        print("Loading dataset_train from {}".format(cache_path))
+        dataset, _ = torch.load(cache_path)
+    else:
+        dataset = torchvision.datasets.ImageFolder(
+            traindir,
+            transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
+        if args.cache_dataset:
+            print("Saving dataset_train to {}".format(cache_path))
+            utils.mkdir(os.path.dirname(cache_path))
+            utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)
 
     print("Loading validation data")
-    dataset_test = torchvision.datasets.ImageFolder(
-        valdir,
-        transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            normalize,
-        ]))
+    cache_path = _get_cache_path(valdir)
+    if args.cache_dataset and os.path.exists(cache_path):
+        # Attention, as the transforms are also cached!
+        print("Loading dataset_test from {}".format(cache_path))
+        dataset_test, _ = torch.load(cache_path)
+    else:
+        dataset_test = torchvision.datasets.ImageFolder(
+            valdir,
+            transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ]))
+        if args.cache_dataset:
+            print("Saving dataset_test to {}".format(cache_path))
+            utils.mkdir(os.path.dirname(cache_path))
+            utils.save_on_master((dataset_test, valdir), cache_path)
 
     print("Creating data loaders")
     if args.distributed:
```
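As the in-diff comment warns, `torch.save` pickles the entire `ImageFolder`, transforms included, so editing the transform pipeline in code has no effect until the stale cache file is removed. A quick sketch of the round-trip, assuming a local `./train` directory in `ImageFolder` layout:

```python
import torch
import torchvision
from torchvision import transforms

dataset = torchvision.datasets.ImageFolder(
    "./train", transforms.Compose([transforms.ToTensor()]))
torch.save((dataset, "./train"), "cached.pt")  # the transforms are pickled too

restored, traindir = torch.load("cached.pt")
print(restored.transform)  # the Compose saved above, not whatever the source now says
```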
```diff
@@ -118,7 +146,7 @@ def main(args):
 
     print("Creating model")
     model = torchvision.models.__dict__[args.model]()
     model.to(device)
-    if args.distributed:
+    if args.distributed and args.sync_bn:
         model = torch.nn.utils.convert_sync_batchnorm(model)
 
     model_without_ddp = model
```
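SyncBatchNorm conversion is now opt-in via `--sync-bn` instead of being applied to every distributed run. For reference, the conversion helper in current PyTorch lives on the `SyncBatchNorm` class itself; a minimal sketch:

```python
import torch
import torchvision

model = torchvision.models.resnet18()
# Replaces every BatchNorm*d module with SyncBatchNorm; statistics are only
# actually synchronized once a distributed process group has been initialized.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```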
```diff
@@ -131,7 +159,6 @@ def main(args):
 
     optimizer = torch.optim.SGD(
         model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    # if using mobilenet, step_size=2 and gamma=0.94
     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
 
     if args.resume:
```
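The removed comment was mobilenet-specific tuning advice; the schedule itself is unchanged. `StepLR` multiplies the learning rate by `gamma` once every `step_size` epochs, i.e. `lr(epoch) = lr * gamma ** (epoch // step_size)`; with illustrative values:

```python
import torch

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)
# epochs 0-29 -> lr 0.1, epochs 30-59 -> lr 0.01, epochs 60+ -> lr 0.001
```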
```diff
@@ -139,6 +166,7 @@ def main(args):
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        args.start_epoch = checkpoint['epoch'] + 1
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
```
```diff
@@ -146,26 +174,32 @@ def main(args):
 
     print("Start training")
     start_time = time.time()
-    for epoch in range(args.epochs):
+    for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         lr_scheduler.step()
         train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
-            utils.save_on_master({
-                'model': model_without_ddp.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'lr_scheduler': lr_scheduler.state_dict(),
-                'args': args},
+            checkpoint = {
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
+                'epoch': epoch,
+                'args': args}
+            utils.save_on_master(
+                checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
+            utils.save_on_master(
+                checkpoint,
+                os.path.join(args.output_dir, 'checkpoint.pth'))
 
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     print('Training time {}'.format(total_time_str))
 
 
-if __name__ == "__main__":
+def parse_args():
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Classification Training')
```
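Each epoch now writes both a per-epoch snapshot (`model_{epoch}.pth`) and a rolling `checkpoint.pth` at a fixed path, so a crashed job can be restarted via `--resume path/to/output/checkpoint.pth` without knowing the last epoch number. A sketch of inspecting the rolling file (the path is illustrative):

```python
import torch

checkpoint = torch.load("output/checkpoint.pth", map_location="cpu")
print(checkpoint["epoch"])  # last completed epoch; training resumes at epoch + 1
```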
```diff
@@ -188,6 +222,20 @@ if __name__ == "__main__":
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument(
+        "--cache-dataset",
+        dest="cache_dataset",
+        help="Cache the datasets for quicker initialization. It also serializes the transforms",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
```
```diff
@@ -202,7 +250,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
-    if args.output_dir:
-        utils.mkdir(args.output_dir)
+    return args
+
 
+if __name__ == "__main__":
+    args = parse_args()
     main(args)
```
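Putting the new flags together, a hypothetical single-GPU run that caches the datasets and resumes from the rolling checkpoint could look like `python train.py --model resnet18 --output-dir /tmp/run --cache-dataset --resume /tmp/run/checkpoint.pth` (model choice and paths are illustrative; the remaining flags keep their defaults).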
references/classification/utils.py

```diff
@@ -214,13 +214,15 @@ def save_on_master(*args, **kwargs):
 
 
 def init_distributed_mode(args):
-    if 'SLURM_PROCID' in os.environ:
-        args.rank = int(os.environ['SLURM_PROCID'])
-        args.gpu = args.rank % torch.cuda.device_count()
-    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
         args.rank = int(os.environ["RANK"])
         args.world_size = int(os.environ['WORLD_SIZE'])
         args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    elif hasattr(args, "rank"):
+        pass
     else:
         print('Not using distributed mode')
         args.distributed = False
```
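The reordering makes an explicit launcher environment (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`) take precedence over SLURM's variables, and the new `hasattr` branch lets a caller that has already set `args.rank` skip detection entirely. A runnable sketch of the precedence, using made-up environment values:

```python
import os
from argparse import Namespace

# Pretend a launcher exported these (values are made up):
os.environ.update({"RANK": "0", "WORLD_SIZE": "2", "LOCAL_RANK": "0"})
os.environ["SLURM_PROCID"] = "5"  # also present, but now checked second

args = Namespace()
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    args.rank = int(os.environ["RANK"])
elif 'SLURM_PROCID' in os.environ:
    args.rank = int(os.environ['SLURM_PROCID'])
print(args.rank)  # 0 -- the launcher environment wins over SLURM_PROCID
```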