Commit eae8b989 authored by Michael Carilli

Adjusting learning rate schedule for 76% accuracy

parent 39c9be85
@@ -87,16 +87,7 @@ import apex_C as apex_backend
```
### Windows support
Windows support is experimental, and Linux is recommended. If you wish to install Apex in Windows, there are two requirements:
1. Apex must be installed in the same Conda environment as Pytorch.
2. Building Apex requires the same Visual Studio environment settings as [building Pytorch from source](https://github.com/pytorch/pytorch#install-pytorch):
```
cd apex_dir
set "VS150COMNTOOLS=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Auxiliary\Build"
call "%VS150COMNTOOLS%\vcvarsall.bat" x64 -vcvars_ver=14.11
python setup.py install
```
You may need to replace `2017`, `Enterprise`, or `vcvars_ver` according to your version of Visual Studio.
Windows support is experimental, and Linux is recommended. However, since Apex is Python-only, there's a good chance it "just works." If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
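For example, a minimal sketch of the Python-only install in the same Conda environment (the environment name `pytorch_env` and the `conda activate` invocation are illustrative assumptions, not part of this repo):
```
conda activate pytorch_env
cd apex_dir
python setup.py install
```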
<!--
reparametrization and RNN API under construction
@@ -199,7 +199,6 @@ def main():
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
@@ -272,6 +271,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
    while input is not None:
        i += 1
        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
        if args.prof:
            if i > 10:
                break
@@ -426,9 +427,22 @@ class AverageMeter(object):
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    # Decay the LR by 10x every 30 epochs, with an extra 10x decay at epoch 80.
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = args.lr * (0.1 ** factor)

    # Warmup: ramp the LR up linearly from zero over the first 5 epochs.
    if epoch < 5:
        lr = lr * float(step + epoch * len_epoch) / (5. * len_epoch)

    if args.local_rank == 0:
        print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
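For reference, here is a minimal standalone sketch of what the new schedule produces, assuming a base LR of 0.1 and roughly 5005 iterations per epoch (ImageNet at batch size 256); both values are illustrative assumptions, not taken from this commit:
```
# Hypothetical standalone sketch of the schedule above; base_lr and
# len_epoch are assumed values, not part of this commit.
base_lr = 0.1       # assumed initial learning rate
len_epoch = 5005    # ~1.28M ImageNet images / batch size 256

def schedule(epoch, step):
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = base_lr * (0.1 ** factor)
    if epoch < 5:  # linear warmup over the first 5 epochs
        lr = lr * float(step + epoch * len_epoch) / (5. * len_epoch)
    return lr

for epoch in (0, 4, 30, 60, 80):
    print(epoch, schedule(epoch, 0))
# 0  -> 0.0     (warmup starts from zero)
# 4  -> 0.08    (4/5 of the way through warmup)
# 30 -> 0.01    (first 10x decay)
# 60 -> 0.001   (second 10x decay)
# 80 -> 0.0001  (extra 10x decay at epoch 80)
```
Moving the call out of the epoch loop in `main()` and into the per-iteration loop in `train()` is what makes the within-epoch warmup possible.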