Commit a2801d91 authored by Michael Carilli

Revising LR scaling to account for any choice of number of processes and batch size per process

parent 01e29c97
@@ -43,9 +43,9 @@ parser.add_argument('--epochs', default=90, type=int, metavar='N',
 parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
-                    metavar='N', help='mini-batch size (default: 256)')
+                    metavar='N', help='mini-batch size per process (default: 256)')
 parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
+                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
 parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
@@ -133,8 +133,8 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
 
-    # Scale learning rate based on per-process batch size
-    args.lr = args.lr*float(args.batch_size)/256.
+    # Scale learning rate based on global batch size
+    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
     optimizer = torch.optim.SGD(master_params, args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
...
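The change above replaces per-process scaling with global scaling: the learning rate now grows with the total number of images consumed per optimization step across all processes, not just one process's batch. As a minimal, hedged sketch (not the commit's code; `base_lr` and `per_process_batch` are placeholder names), the same rule can be written as a standalone helper that reads the process count from torch.distributed:

```python
import torch.distributed as dist

def scaled_lr(base_lr=0.1, per_process_batch=256):
    """Illustrative sketch of the linear LR scaling rule applied in this
    commit; names here are placeholders, not taken from the diff."""
    # world_size = total number of participating processes (1 if not distributed)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    global_batch = per_process_batch * world_size
    # Linear scaling: LR grows proportionally to global batch size / 256
    return base_lr * float(global_batch) / 256.
```

For example, with 8 processes at 256 images each, the global batch size is 2048, so the default LR of 0.1 scales to 0.8.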
@@ -43,9 +43,9 @@ parser.add_argument('--epochs', default=90, type=int, metavar='N',
 parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
-                    metavar='N', help='mini-batch size (default: 256)')
+                    metavar='N', help='mini-batch size per process (default: 256)')
 parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
+                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
 parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
@@ -134,8 +134,8 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
 
-    # Scale learning rate based on per-process batch size
-    args.lr = args.lr*float(args.batch_size)/256.
+    # Scale learning rate based on global batch size
+    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
...
@@ -43,9 +43,9 @@ parser.add_argument('--epochs', default=90, type=int, metavar='N',
 parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
-                    metavar='N', help='mini-batch size (default: 256)')
+                    metavar='N', help='mini-batch size per process (default: 256)')
 parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
+                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
 parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
@@ -133,8 +133,8 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
 
-    # Scale learning rate based on per-process batch size
-    args.lr = args.lr*float(args.batch_size)/256.
+    # Scale learning rate based on global batch size
+    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
     optimizer = torch.optim.SGD(master_params, args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
...
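The updated `--lr` help string also promises a warmup schedule over the first 5 epochs, whose implementation is not visible in this diff excerpt. A minimal sketch of what such a linear per-iteration warmup could look like, assuming a hypothetical `adjust_learning_rate` helper called once per training step (the name and signature are illustrative, not from the commit):

```python
def adjust_learning_rate(optimizer, epoch, step, len_epoch, base_lr):
    """Hypothetical helper: linearly ramp the LR from 0 to base_lr over the
    first 5 epochs, as described in the --lr help text. The commit's actual
    schedule code is not part of this excerpt."""
    lr = base_lr
    if epoch < 5:
        # Fractional progress through the warmup, measured in iterations
        lr = base_lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
```

After epoch 5, the function leaves the (already globally scaled) base LR in place, so the warmup composes cleanly with the scaling rule introduced in this commit.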