Commit e86f986d authored by Syed Tousif Ahmed's avatar Syed Tousif Ahmed
Browse files

Put parser in a function to make script importable

parent 8d0deb09
...@@ -25,89 +25,92 @@ try: ...@@ -25,89 +25,92 @@ try:
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
model_names = sorted(name for name in models.__dict__
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
def parse():
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
and callable(models.__dict__[name])) and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
help='path to dataset') help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names, choices=model_names,
help='model architecture: ' + help='model architecture: ' +
' | '.join(model_names) + ' | '.join(model_names) +
' (default: resnet18)') ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)') help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N', parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run') help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)') metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.') metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum') help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)') metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set') help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model') help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int, parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.') help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true') parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true', parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.') help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str) parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None)
args = parser.parse_args()
cudnn.benchmark = True return args
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
best_prec1 = 0 def main():
args = parser.parse_args() global best_prec1, args
print("opt_level = {}".format(args.opt_level)) args = parse()
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("opt_level = {}".format(args.opt_level))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
if args.deterministic: cudnn.benchmark = True
best_prec1 = 0
if args.deterministic:
cudnn.benchmark = False cudnn.benchmark = False
cudnn.deterministic = True cudnn.deterministic = True
torch.manual_seed(args.local_rank) torch.manual_seed(args.local_rank)
torch.set_printoptions(precision=10) torch.set_printoptions(precision=10)
def main():
global best_prec1, args
args.distributed = False args.distributed = False
if 'WORLD_SIZE' in os.environ: if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.distributed = int(os.environ['WORLD_SIZE']) > 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment