Commit dc41c5ce authored by Michael Carilli

Adjusting learning rate for batch size

parent 437bcf22
@@ -64,9 +64,6 @@ $ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_fp16_optimiz
 ## Usage for `main.py` and `main_fp16_optimizer.py`
-```bash
-```
 `main_fp16_optimizer.py` also accepts the optional flag
 ```bash
 --dynamic-loss-scale  Use dynamic loss scaling. If supplied, this argument
...
@@ -133,6 +133,8 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
+    # Scale learning rate based on per-process batch size
+    args.lr = args.lr*float(args.batch_size)/256.
     optimizer = torch.optim.SGD(master_params, args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
...
@@ -134,6 +134,8 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
+    # Scale learning rate based on per-process batch size
+    args.lr = args.lr*float(args.batch_size)/256.
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
...
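The added lines apply the linear learning-rate scaling rule: the supplied `--lr` is treated as tuned for a batch size of 256, so the effective rate is multiplied by `batch_size / 256` before the SGD optimizer is built. Below is a minimal standalone sketch of that arithmetic; the `scale_lr` helper, the argparse defaults, and the `nn.Linear` stand-in model are illustrative assumptions, not part of this commit.

```python
import argparse

import torch


def scale_lr(base_lr, batch_size, reference_batch=256):
    # Linear scaling rule: grow the learning rate in proportion to the
    # batch size, relative to the reference batch size (256 here, the
    # value the base learning rate is assumed to be tuned for).
    return base_lr * float(batch_size) / reference_batch


def main():
    # Hypothetical flags standing in for the example scripts' argparse options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.1,
                        help="learning rate tuned for a batch size of 256")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="per-process batch size")
    args = parser.parse_args()

    model = torch.nn.Linear(10, 10)  # stand-in for the real network

    # Same arithmetic as the diff: lr * batch_size / 256.
    args.lr = scale_lr(args.lr, args.batch_size)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=0.9, weight_decay=1e-4)
    print("scaled lr:", optimizer.param_groups[0]["lr"])


if __name__ == "__main__":
    main()
```

Note that the commit scales by the per-process batch size, as its comment states; in a multi-GPU launch the global batch size is the per-process batch size times the number of processes.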