Unverified Commit 47144979 authored by mcarilli, committed by GitHub

Merge pull request #173 from NVIDIA/api_refactor

Unified mixed precision API + backend performance improvements
parents 1603407b 6644c6e6
# Mixed Precision ImageNet Training in PyTorch

`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/imagenet](https://github.com/pytorch/examples/tree/master/imagenet).
It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset, and illustrates use of the new Amp API along with command-line flags (forwarded to `amp.initialize`) to easily manipulate and switch between various pure and mixed precision training modes.
Three lines enable Amp:
```
# Added after model and optimizer construction
model, optimizer = amp.initialize(model, optimizer, flags...)
...
# loss.backward() changed to:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```

With the new Amp API **you never need to explicitly convert your model, or the input data, to half().**
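For orientation, the snippet below sketches how those three lines sit inside an ordinary training loop. This is a minimal, illustrative sketch: the toy linear model, random data, and SGD hyperparameters are assumptions, not part of `main_amp.py`.
```
# Minimal sketch of the Amp pattern (assumes a CUDA device and apex installed).
import torch
from apex import amp

model = torch.nn.Linear(784, 10).cuda()          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1) Added once, after model and optimizer construction
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

criterion = torch.nn.CrossEntropyLoss().cuda()
for step in range(10):
    input = torch.randn(32, 784).cuda()           # FP32 input; no manual .half()
    target = torch.randint(0, 10, (32,)).cuda()
    optimizer.zero_grad()
    loss = criterion(model(input), target)
    # 2) + 3) loss.backward() becomes a scaled backward pass
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```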

## Requirements

- `pip install -r requirements.txt`
- Download the ImageNet dataset and move validation images to labeled subfolders
  - The following script may be helpful: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh

## Training

To train a model, create softlinks to the Imagenet dataset, then run `main_amp.py` with the desired model architecture, as shown in `Example commands` below.

The `main_amp.py` script rescales the learning rate according to the global batch size (number of distributed processes \* per-process minibatch size).
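The scaling is linear in the global batch size, as the worked example below shows; the 4-process, per-process batch size 224 numbers are assumptions chosen for illustration.
```
# Learning-rate scaling as applied in main_amp.py:
#   args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
base_lr = 0.1        # default --lr
batch_size = 224     # per-process minibatch (--b 224)
world_size = 4       # e.g. 4 distributed processes (assumed for illustration)

scaled_lr = base_lr * float(batch_size * world_size) / 256.
print(scaled_lr)     # 0.35
```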
## Example commands

**Note:** batch size `--b 224` assumes your GPUs have >=16GB of onboard memory. You may be able to increase this to 256, but that's cutting it close, so it may run out of memory with some Pytorch versions.

**Note:** All of the following use 4 dataloader subprocesses (`--workers 4`) to reduce potential
CPU data loading bottlenecks.
**Note:** `--opt-level` `O1` and `O2` both use dynamic loss scaling by default unless manually overridden.
`--opt-level` `O0` and `O3` (the "pure" training modes) do not use loss scaling by default.
`O0` and `O3` can be told to use loss scaling via manual overrides, but using loss scaling with `O0`
(pure FP32 training) does not really make sense, and will trigger a warning.
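These flags are forwarded directly to `amp.initialize`, which (as noted in the script) accepts either values or strings for the optional overrides, so argparse values can be passed through unchanged. A rough sketch of the forwarding follows; the toy model, optimizer, and the `O1` default are assumptions for illustration (the actual script defines no default).
```
# Sketch: forwarding --opt-level / --keep-batchnorm-fp32 / --loss-scale to amp.initialize.
import argparse
import torch
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--opt-level', type=str, default='O1')             # default assumed here
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)            # e.g. "128.0" or "dynamic"
args = parser.parse_args()

model = torch.nn.Linear(16, 4).cuda()                                  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Amp accepts strings for the optional overrides, for convenient interoperation with argparse.
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level=args.opt_level,
                                  keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                  loss_scale=args.loss_scale)
```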
Softlink training and validation datasets into the current directory:
```
$ ln -sf /data/imagenet/train-jpeg/ train
$ ln -sf /data/imagenet/val-jpeg/ val
```

### Summary

Amp allows easy experimentation with various pure and mixed precision options.
```
$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
```
Options are broken down in detail below.
#### `--opt-level O0` (FP32 training) and `O3` (FP16 training)
"Pure FP32" training:
```
$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./
```
"Pure FP16" training:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./
```
FP16 training with FP32 batchnorm:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./
```
Keeping the batchnorms in FP32 improves stability and allows Pytorch
to use cudnn batchnorms, which significantly increases speed in Resnet50.
The `O3` options might not converge, because they are not true mixed precision.
However, they can be useful to establish "speed of light" performance for
your model, which provides a baseline for comparison with `O1` and `O2`.
For Resnet50 in particular, `--opt-level O3 --keep-batchnorm-fp32 True` establishes
the "speed of light." (Without `--keep-batchnorm-fp32`, it's slower, because it does
not use cudnn batchnorm.)
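For reference, the `O0`/`O3` command lines above correspond roughly to the following `amp.initialize` calls; this is a sketch, and the toy model and optimizer are placeholders rather than the ResNet50 used by `main_amp.py`.
```
# Programmatic equivalents of the O0 / O3 runs (placeholder model/optimizer).
import torch
from apex import amp

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# "Pure FP32" baseline
model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

# "Pure FP16" (may not converge; useful as a speed-of-light measurement):
# model, optimizer = amp.initialize(model, optimizer, opt_level="O3")

# FP16 with FP32 batchnorm, the speed-of-light configuration for Resnet50:
# model, optimizer = amp.initialize(model, optimizer, opt_level="O3", keep_batchnorm_fp32=True)
```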
#### `--opt-level O1` ("conservative mixed precision")
`O1` patches Torch functions to cast inputs according to a whitelist-blacklist model.
FP16-friendly (Tensor Core) ops like gemms and convolutions run in FP16, while ops
that benefit from FP32, like batchnorm and softmax, run in FP32.
Also, dynamic loss scaling is used by default.
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
```
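To make the whitelist/blacklist casting concrete, here is a rough, hedged check of the dtypes produced by a whitelisted op (a linear layer, i.e. a gemm) versus a blacklisted op (softmax) after initializing with `O1`. The toy layer and shapes are assumptions; treat the expected dtypes as an illustration of the idea rather than a specification of Amp internals.
```
# Rough illustration of O1 casting (assumes a CUDA device and apex installed).
import torch
import torch.nn.functional as F
from apex import amp

model = torch.nn.Linear(16, 4).cuda()          # toy layer, FP32 parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 16).cuda()                  # FP32 input; no manual .half()
y = model(x)
print(y.dtype)                                 # expected: torch.float16 (gemm is whitelisted)
print(F.softmax(y, dim=1).dtype)               # expected: torch.float32 (softmax is blacklisted)
```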
`O1` overridden to use static loss scaling:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./
```
Distributed training with 2 processes (1 GPU per process, see **Distributed training** below
for more detail)
```
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./
```
For best performance, set `--nproc_per_node` equal to the total number of GPUs on the node
to use all available resources.
#### `--opt-level O2` ("fast mixed precision")
`O2` casts the model to FP16, keeps batchnorms in FP32,
maintains master weights in FP32, and implements
dynamic loss scaling by default. (Unlike --opt-level O1, --opt-level O2
does not patch Torch functions.)
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
```
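A rough way to see what `O2` does to the model and optimizer state is sketched below; the toy conv+batchnorm model is an assumption, and the expected dtypes illustrate the description above rather than pin down Amp internals.
```
# Rough illustration of O2 (assumes a CUDA device and apex installed).
import torch
from apex import amp

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

print(model[0].weight.dtype)   # expected: torch.float16 (model cast to FP16)
print(model[1].weight.dtype)   # expected: torch.float32 (batchnorm kept in FP32)
# The optimizer now steps on FP32 master weights:
print(all(p.dtype == torch.float32
          for group in optimizer.param_groups for p in group['params']))   # expected: True
```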
"Fast mixed precision" overridden to use static loss scaling:
```
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./
```
Distributed training with 2 processes (1 GPU per process)
```
$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./
```
## Distributed training

`main_amp.py` optionally uses `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.
```
model = apex.parallel.DistributedDataParallel(model)
```
is a drop-in replacement for
```
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[arg.local_rank],
output_device=arg.local_rank)
```
(because Torch DDP permits multiple GPUs per process, with Torch DDP you are required to
manually specify the device to run on and the output device.
With Apex DDP, it uses only the current device by default).
The choice of DDP wrapper (Torch or Apex) is orthogonal to the use of Amp and other Apex tools. It is safe to use `apex.amp` with either `torch.nn.parallel.DistributedDataParallel` or `apex.parallel.DistributedDataParallel`. In the future, I may add some features that permit optional tighter integration between `Amp` and `apex.parallel.DistributedDataParallel` for marginal performance benefits, but currently, there's no compelling reason to use Apex DDP versus Torch DDP for most models.

To use DDP with `apex.amp`, the only gotcha is that
```
model, optimizer = amp.initialize(model, optimizer, flags...)
```
must precede
```
model = DDP(model)
```
If DDP wrapping occurs before `amp.initialize`, `amp.initialize` will raise an error.
With both Apex DDP and Torch DDP, you must also call `torch.cuda.set_device(args.local_rank)` within
each process prior to initializing your model or any other tensors.
More information can be found in the docs for the
Pytorch multiprocess launcher module [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility).
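Pulling the distributed pieces together, the per-process order of operations looks roughly like the sketch below, condensed from `main_amp.py`; the toy model and optimizer are placeholders.
```
# Per-process setup order (condensed sketch; placeholder model/optimizer).
import argparse
import torch
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)   # supplied by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)                     # before creating the model or any tensors
torch.distributed.init_process_group(backend='nccl', init_method='env://')

model = torch.nn.Linear(784, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = DDP(model, delay_allreduce=True)                   # DDP wrapping comes AFTER amp.initialize
```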
`main_amp.py` is written to interact with
[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility),
which spawns multiprocess jobs using the following syntax:
```
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_amp.py args...
```
`NUM_GPUS` should be less than or equal to the number of visible GPU devices on the node. The use of `torch.distributed.launch` is unrelated to the choice of DDP wrapper. It is safe to use either apex DDP or torch DDP with `torch.distributed.launch`.
Optionally, one can run imagenet with synchronized batch normalization across processes by adding
`--sync_bn` to the `args...`
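Under the hood, `--sync_bn` converts the model's batchnorm layers via `apex.parallel.convert_syncbn_model` before the model is moved to the GPU, as the script does; a minimal sketch with a toy model:
```
# What --sync_bn does in main_amp.py, in isolation (toy model assumed).
import torch
import apex

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
model = apex.parallel.convert_syncbn_model(model)   # BatchNorm layers -> apex SyncBatchNorm
model = model.cuda()
```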
## Deterministic training (for debugging purposes)
Running with the `--deterministic` flag should produce bitwise identical outputs run-to-run,
regardless of what other options are used (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)).
Since `--deterministic` disables `torch.backends.cudnn.benchmark`, `--deterministic` may
cause a modest performance decrease.
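For reference, `--deterministic` in this script boils down to the following settings; seeding with the process's `local_rank` gives each process a distinct but reproducible seed.
```
# What --deterministic enables in main_amp.py.
import torch
import torch.backends.cudnn as cudnn

local_rank = 0                   # args.local_rank in the script (one value per process)
cudnn.benchmark = False          # disable cudnn autotuning, the main source of run-to-run variation
cudnn.deterministic = True       # force deterministic cudnn kernels
torch.manual_seed(local_rank)    # reproducible per-process seed
torch.set_printoptions(precision=10)   # print enough digits to compare losses across runs
```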
@@ -20,7 +20,8 @@ import numpy as np
 try:
     from apex.parallel import DistributedDataParallel as DDP
     from apex.fp16_utils import *
-    from apex import amp
+    from apex import amp, optimizers
+    from apex.multi_tensor_apply import multi_tensor_applier
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
@@ -59,8 +60,6 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--fp16', action='store_true',
-                    help='Run model fp16 mode.')
 parser.add_argument('--prof', dest='prof', action='store_true',
                     help='Only run 10 iterations for profiling.')
 parser.add_argument('--deterministic', action='store_true')
@@ -69,6 +68,11 @@ parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--sync_bn', action='store_true',
                     help='enabling apex sync BN.')
+parser.add_argument('--has-ext', action='store_true')
+parser.add_argument('--opt-level', type=str)
+parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
+parser.add_argument('--loss-scale', type=str, default=None)
 cudnn.benchmark = True
 def fast_collate(batch):
@@ -91,13 +95,17 @@ def fast_collate(batch):
 best_prec1 = 0
 args = parser.parse_args()
+print("opt_level = {}".format(args.opt_level))
+print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
+print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
+print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
 if args.deterministic:
     cudnn.benchmark = False
     cudnn.deterministic = True
     torch.manual_seed(args.local_rank)
+    torch.set_printoptions(precision=10)
-# Initialize Amp
-amp_handle = amp.init(enabled=args.fp16)
 def main():
     global best_prec1, args
@@ -116,8 +124,7 @@ def main():
                                              init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-    if args.fp16:
-        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
+    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
     # create model
     if args.pretrained:
@@ -134,6 +141,24 @@ def main():
     model = model.cuda()
+    # Scale learning rate based on global batch size
+    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
+    # for convenient interoperation with argparse.
+    model, optimizer = amp.initialize(model, optimizer,
+                                      opt_level=args.opt_level,
+                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
+                                      loss_scale=args.loss_scale
+                                      )
+    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
+    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
+    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
+    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
     if args.distributed:
         # By default, apex.parallel.DistributedDataParallel overlaps communication with
         # computation in the backward pass.
@@ -144,12 +169,6 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
-    # Scale learning rate based on global batch size
-    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
     # Optionally resume from a checkpoint
     if args.resume:
         # Use a local scope to avoid dangling references
@@ -242,7 +261,6 @@ class data_prefetcher():
         self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
         self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
         # With Amp, it isn't necessary to manually convert data to half.
-        # Type conversions are done internally on the fly within patched torch functions.
         # if args.fp16:
         #     self.mean = self.mean.half()
         #     self.std = self.std.half()
@@ -259,7 +277,6 @@ class data_prefetcher():
             self.next_input = self.next_input.cuda(non_blocking=True)
             self.next_target = self.next_target.cuda(non_blocking=True)
             # With Amp, it isn't necessary to manually convert data to half.
-            # Type conversions are done internally on the fly within patched torch functions.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
@@ -276,7 +293,6 @@ class data_prefetcher():
 def train(train_loader, model, criterion, optimizer, epoch):
     batch_time = AverageMeter()
-    data_time = AverageMeter()
     losses = AverageMeter()
     top1 = AverageMeter()
     top5 = AverageMeter()
@@ -287,7 +303,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
     prefetcher = data_prefetcher(train_loader)
     input, target = prefetcher.next()
-    i = -1
+    i = 0
     while input is not None:
         i += 1
@@ -296,55 +312,67 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.prof:
             if i > 10:
                 break
-        # measure data loading time
-        data_time.update(time.time() - end)
         # compute output
+        if args.prof: torch.cuda.nvtx.range_push("forward")
         output = model(input)
+        if args.prof: torch.cuda.nvtx.range_pop()
         loss = criterion(output, target)
-        # measure accuracy and record loss
-        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-        if args.distributed:
-            reduced_loss = reduce_tensor(loss.data)
-            prec1 = reduce_tensor(prec1)
-            prec5 = reduce_tensor(prec5)
-        else:
-            reduced_loss = loss.data
-        losses.update(to_python_float(reduced_loss), input.size(0))
-        top1.update(to_python_float(prec1), input.size(0))
-        top5.update(to_python_float(prec5), input.size(0))
         # compute gradient and do SGD step
         optimizer.zero_grad()
-        with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
+        if args.prof: torch.cuda.nvtx.range_push("backward")
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward()
+        if args.prof: torch.cuda.nvtx.range_pop()
-        optimizer.step()
-        torch.cuda.synchronize()
-        # measure elapsed time
-        batch_time.update(time.time() - end)
-        end = time.time()
+        # for param in model.parameters():
+        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())
+        if args.prof: torch.cuda.nvtx.range_push("step")
+        optimizer.step()
+        if args.prof: torch.cuda.nvtx.range_pop()
         input, target = prefetcher.next()
-        if args.local_rank == 0 and i % args.print_freq == 0 and i > 1:
-            print('Epoch: [{0}][{1}/{2}]\t'
-                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                  'Speed {3:.3f} ({4:.3f})\t'
-                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
-                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                   epoch, i, len(train_loader),
-                   args.world_size * args.batch_size / batch_time.val,
-                   args.world_size * args.batch_size / batch_time.avg,
-                   batch_time=batch_time,
-                   data_time=data_time, loss=losses, top1=top1, top5=top5))
+        if i%args.print_freq == 0:
+            # Every print_freq iterations, check the loss accuracy and speed.
+            # For best performance, it doesn't make sense to print these metrics every
+            # iteration, since they incur an allreduce and some host<->device syncs.
+            # Measure accuracy
+            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+            # Average loss and accuracy across processes for logging
+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data)
+                prec1 = reduce_tensor(prec1)
+                prec5 = reduce_tensor(prec5)
+            else:
+                reduced_loss = loss.data
+            # to_python_float incurs a host<->device sync
+            losses.update(to_python_float(reduced_loss), input.size(0))
+            top1.update(to_python_float(prec1), input.size(0))
+            top5.update(to_python_float(prec5), input.size(0))
+            torch.cuda.synchronize()
+            batch_time.update((time.time() - end)/args.print_freq)
+            end = time.time()
+            if args.local_rank == 0:
+                print('Epoch: [{0}][{1}/{2}]\t'
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                      'Speed {3:.3f} ({4:.3f})\t'
+                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
+                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                       epoch, i, len(train_loader),
+                       args.world_size*args.batch_size/batch_time.val,
+                       args.world_size*args.batch_size/batch_time.avg,
+                       batch_time=batch_time,
+                       loss=losses, top1=top1, top5=top5))
 def validate(val_loader, model, criterion):
@@ -360,7 +388,7 @@ def validate(val_loader, model, criterion):
     prefetcher = data_prefetcher(val_loader)
     input, target = prefetcher.next()
-    i = -1
+    i = 0
     while input is not None:
         i += 1
@@ -387,6 +415,7 @@ def validate(val_loader, model, criterion):
         batch_time.update(time.time() - end)
         end = time.time()
+        # TODO: Change timings to mirror train().
         if args.local_rank == 0 and i % args.print_freq == 0:
             print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
...
@@ -12,13 +12,16 @@ TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
-    raise RuntimeError("APEx requires Pytorch 0.4 or newer.\n" +
+    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
                        "The latest stable release can be obtained from https://pytorch.org/")
 cmdclass = {}
 ext_modules = []
 if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
+    if TORCH_MAJOR == 0:
+        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
+                           "found torch.__version__ = {}".format(torch.__version__))
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -34,12 +37,23 @@ if "--cuda_ext" in sys.argv:
     sys.argv.remove("--cuda_ext")
     if torch.utils.cpp_extension.CUDA_HOME is None:
-        print("Warning: nvcc is not available. Ignoring --cuda-ext")
+        raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
+        # Set up macros for forward/backward compatibility hack around
+        # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
+        version_ge_1_1 = []
+        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
+            version_ge_1_1 = ['-DVERSION_GE_1_1']
         ext_modules.append(
             CUDAExtension(name='amp_C',
-                          sources=['csrc/scale_check_overflow.cpp',
-                                   'csrc/scale_check_overflow_kernel.cu']))
+                          sources=['csrc/amp_C_frontend.cpp',
+                                   'csrc/scale_check_overflow_kernel.cu',
+                                   'csrc/multi_tensor_scale_kernel.cu'],
+                          extra_compile_args={'cxx': ['-O3'],
+                                              'nvcc':['-lineinfo',
+                                                      '-O3',
+                                                      '--use_fast_math']}))
         ext_modules.append(
             CUDAExtension(name='fused_adam_cuda',
                           sources=['apex/optimizers/csrc/fused_adam_cuda.cpp',
@@ -55,10 +69,10 @@ if "--cuda_ext" in sys.argv:
             CUDAExtension(name='fused_layer_norm_cuda',
                           sources=['apex/normalization/csrc/layer_norm_cuda.cpp',
                                    'apex/normalization/csrc/layer_norm_cuda_kernel.cu'],
-                          extra_compile_args={'cxx': ['-O3',],
+                          extra_compile_args={'cxx': ['-O3'] + version_ge_1_1,
                                               'nvcc':['-maxrregcount=50',
                                                       '-O3',
-                                                      '--use_fast_math']}))
+                                                      '--use_fast_math'] + version_ge_1_1}))
 setup(
     name='apex',
...
@@ -76,7 +76,7 @@ class TestCache(unittest.TestCase):
                 param.grad = None
             loss = model(self.x).sum()
-            self.handle._default_scaler._loss_scale = 1.0
+            self.handle._default_scaler._loss_scale = 4.0
             with self.handle.scale_loss(loss, dummy_optimizer) as scaled_loss:
                 scaled_loss.backward()
...
import unittest
import functools as ft
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
try:
import amp_C
from amp_C import multi_tensor_scale
from apex.multi_tensor_apply import MultiTensorApply
disabled = False
except ImportError as err:
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
disabled = True
class TestMultiTensorScale(unittest.TestCase):
def setUp(self):
self.scale = 4.0
self.overflow_buf = torch.cuda.IntTensor(1).zero_()
self.ref = torch.cuda.FloatTensor([1.0])
common_init(self)
def tearDown(self):
pass
# The tensor creation here is written for convenience, not speed.
def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False):
self.overflow_buf.zero_()
a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
out_list = []
for i in range(repeat_tensors):
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
else:
in_list = [out.clone().to(in_type) for out in out_list]
applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
self.assertTrue(self.overflow_buf.item() == 0)
def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
self.overflow_buf.zero_()
a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
out_list = []
for i in range(repeat_tensors):
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
else:
in_list = [out.clone().to(in_type) for out in out_list]
applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
self.overflow_buf.zero_()
in_list[t][ind] = val
applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
self.assertTrue(self.overflow_buf.item())
# Currently, the fused kernel gives a hard error if you attempt to downscale
# into fp16 output, which imo is the desired behavior. Maybe someday we
# will learn otherwise.
# @unittest.skipIf(disabled, "amp_C is unavailable")
# def test_fp16_to_fp16(self):
# self.downscale(self.fp16, self.fp16, self.fp16_ref)
#
# @unittest.skipIf(disabled, "amp_C is unavailable")
# def test_fp32_to_fp16(self):
# self.downscale(self.fp32, self.fp16, self.fp16_ref)
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_fuzz(self):
input_size_pairs = (
(7777*77, 555*555),
(777, 555),
(555, 2048*32+1),
(2048*32+1, 555),
(555, 2048*32),
(2048*32, 555),
(33333, 555),
(555, 33333))
appliers = (
MultiTensorApply(2048*32),
MultiTensorApply(333),
MultiTensorApply(33333))
repeat_tensors = (
1,
55)
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for in_type in (torch.float32, torch.float16):
for out_type in (torch.float32, torch.float16):
for inplace in (True, False):
if inplace is True and (out_type is not in_type):
continue
else:
self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
0, 0, float('nan'), inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
if __name__ == '__main__':
unittest.main()
@@ -51,6 +51,7 @@ class TestFP16Optimizer(unittest.TestCase):
         self.assertLessEqual(max_abs_diff, self.max_abs_diff)
         self.assertLessEqual(max_rel_diff, self.max_rel_diff)
     def test_loss_scaling(self):
         ref_optim = torch.optim.Adam(self.ref_model.parameters())
...
import argparse
import torch
parser = argparse.ArgumentParser(description='Compare')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--fused-adam', action='store_true')
args = parser.parse_args()
base_file = str(args.opt_level) + "_" +\
str(args.loss_scale) + "_" +\
str(args.keep_batchnorm_fp32) + "_" +\
str(args.fused_adam)
file_e = "True_" + base_file
file_p = "False_" + base_file
dict_e = torch.load(file_e)
dict_p = torch.load(file_p)
torch.set_printoptions(precision=10)
print(file_e)
print(file_p)
for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)
loss_e = dict_e["Loss"][n]
loss_p = dict_p["Loss"][n]
assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)
print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format(
i_e,
loss_e,
loss_p,
dict_e["Speed"][n],
dict_p["Speed"][n]))
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--prof', dest='prof', action='store_true',
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--has-ext', action='store_true')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--fused-adam', action='store_true')
parser.add_argument('--prints-to-process', type=int, default=10)
cudnn.benchmark = True
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
tens = torch.from_numpy(nump_array)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
best_prec1 = 0
args = parser.parse_args()
# Let multi_tensor_applier be the canary in the coalmine
# that verifies if the backend is what we think it is
assert multi_tensor_applier.available == args.has_ext
print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
if args.deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.manual_seed(args.local_rank)
torch.set_printoptions(precision=10)
def main():
global best_prec1, args
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.gpu = 0
args.world_size = 1
if args.distributed:
args.gpu = args.local_rank % torch.cuda.device_count()
torch.cuda.set_device(args.gpu)
torch.distributed.init_process_group(backend='nccl',
init_method='env://')
args.world_size = torch.distributed.get_world_size()
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()
if args.sync_bn:
import apex
print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model)
model = model.cuda()
# Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
if args.fused_adam:
optimizer = optimizers.FusedAdam(model.parameters())
else:
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
model, optimizer = amp.initialize(
model, optimizer,
# enabled=False,
opt_level=args.opt_level,
keep_batchnorm_fp32=args.keep_batchnorm_fp32,
loss_scale=args.loss_scale
)
if args.distributed:
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# model = DDP(model)
# delay_allreduce delays all communication to the end of the backward pass.
model = DDP(model, delay_allreduce=True)
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# Optionally resume from a checkpoint
if args.resume:
# Use a local scope to avoid dangling references
def resume():
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
resume()
# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
if(args.arch == "inception_v3"):
crop_size = 299
val_size = 320 # I chose this value arbitrarily, we can adjust.
else:
crop_size = 224
val_size = 256
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
# transforms.ToTensor(), Too slow
# normalize,
]))
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize(val_size),
transforms.CenterCrop(crop_size),
]))
train_sampler = None
val_sampler = None
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True,
sampler=val_sampler,
collate_fn=fast_collate)
if args.evaluate:
validate(val_loader, model, criterion)
return
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
# train for one epoch
train(train_loader, model, criterion, optimizer, epoch)
if args.prof:
break
# evaluate on validation set
prec1 = validate(val_loader, model, criterion)
# remember best prec@1 and save checkpoint
if args.local_rank == 0:
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}, is_best)
class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.preload()
def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
self.preload()
return input, target
def train(train_loader, model, criterion, optimizer, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
# switch to train mode
model.train()
end = time.time()
run_info_dict = {"Iteration" : [],
"Loss" : [],
"Speed" : []}
prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next()
i = -1
while input is not None:
i += 1
# No learning rate warmup for this test, to expose bitwise inaccuracies more quickly
# adjust_learning_rate(optimizer, epoch, i, len(train_loader))
if args.prof:
if i > 10:
break
# measure data loading time
data_time.update(time.time() - end)
# compute output
output = model(input)
loss = criterion(output, target)
# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data)
prec1 = reduce_tensor(prec1)
prec5 = reduce_tensor(prec5)
else:
reduced_loss = loss.data
losses.update(to_python_float(reduced_loss), input.size(0))
top1.update(to_python_float(prec1), input.size(0))
top5.update(to_python_float(prec5), input.size(0))
# compute gradient and do SGD step
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
# for param in model.parameters():
# print(param.data.double().sum().item(), param.grad.data.double().sum().item())
# torch.cuda.synchronize()
torch.cuda.nvtx.range_push("step")
optimizer.step()
torch.cuda.nvtx.range_pop()
torch.cuda.synchronize()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
input, target = prefetcher.next()
if i % args.print_freq == 0 and i > 1:
if args.local_rank == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Speed {3:.3f} ({4:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader),
args.world_size * args.batch_size / batch_time.val,
args.world_size * args.batch_size / batch_time.avg,
batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
run_info_dict["Iteration"].append(i)
run_info_dict["Loss"].append(losses.val)
run_info_dict["Speed"].append(args.world_size * args.batch_size / batch_time.val)
if len(run_info_dict["Loss"]) == args.prints_to_process:
if args.local_rank == 0:
torch.save(run_info_dict,
str(args.has_ext) + "_" + str(args.opt_level) + "_" +
str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
str(args.fused_adam))
quit()
def validate(val_loader, model, criterion):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
# switch to evaluate mode
model.eval()
end = time.time()
prefetcher = data_prefetcher(val_loader)
input, target = prefetcher.next()
i = -1
while input is not None:
i += 1
# compute output
with torch.no_grad():
output = model(input)
loss = criterion(output, target)
# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data)
prec1 = reduce_tensor(prec1)
prec5 = reduce_tensor(prec5)
else:
reduced_loss = loss.data
losses.update(to_python_float(reduced_loss), input.size(0))
top1.update(to_python_float(prec1), input.size(0))
top5.update(to_python_float(prec5), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Speed {2:.3f} ({3:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader),
args.world_size * args.batch_size / batch_time.val,
args.world_size * args.batch_size / batch_time.avg,
batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
input, target = prefetcher.next()
print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
"""LR schedule that should yield 76% converged accuracy with batch size 256"""
factor = epoch // 30
if epoch >= 80:
factor = factor + 1
lr = args.lr*(0.1**factor)
"""Warmup"""
if epoch < 5:
lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
# if(args.local_rank == 0):
# print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM)
rt /= args.world_size
return rt
if __name__ == '__main__':
main()
#!/bin/bash
print_banner() {
printf "\n\n\n\e[30m\e[42m$1\e[0m\n\n\n\n"
}
print_banner "Distributed status: $1"
# DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/"
DATADIR="/opt/home/apex/examples/imagenet/"
if [ "$1" == "single_gpu" ]
then
BASE_CMD="python main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5"
fi
if [ "$1" == "distributed" ]
then
BASE_CMD="python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5"
fi
ADAM_ARGS="--opt-level O2 --keep-batchnorm-fp32 False --fused-adam"
keep_batchnorms=(
""
"--keep-batchnorm-fp32 True"
"--keep-batchnorm-fp32 False"
)
loss_scales=(
""
"--loss-scale 1.0"
"--loss-scale 128.0"
"--loss-scale dynamic"
)
opt_levels=(
"O0"
"O1"
"O2"
"O3"
)
rm True*
rm False*
set -e
print_banner "Installing Apex with --cuda_ext and --cpp_ext"
pushd ../../..
python setup.py install --cuda_ext --cpp_ext
popd
for opt_level in "${opt_levels[@]}"
do
for loss_scale in "${loss_scales[@]}"
do
for keep_batchnorm in "${keep_batchnorms[@]}"
do
if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
then
print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
continue
fi
print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
set -x
${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
set +x
done
done
done
# Handle FusedAdam separately due to limited support.
# FusedAdam will not be tested for bitwise accuracy against the Python implementation.
# The L0 tests already do so. These tests are here to ensure that it actually runs,
# and get an idea of performance.
for loss_scale in "${loss_scales[@]}"
do
print_banner "${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR"
set -x
${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR
set +x
done
print_banner "Reinstalling apex without extensions"
pushd ../../..
python setup.py install
popd
for opt_level in "${opt_levels[@]}"
do
for loss_scale in "${loss_scales[@]}"
do
for keep_batchnorm in "${keep_batchnorms[@]}"
do
if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
then
print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
continue
fi
print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR"
set -x
${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR
set +x
done
done
done
print_banner "Checking for bitwise accuracy between Python-only and cpp/cuda extension installs"
for opt_level in "${opt_levels[@]}"
do
for loss_scale in "${loss_scales[@]}"
do
for keep_batchnorm in "${keep_batchnorms[@]}"
do
echo ""
if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ]
then
echo "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}"
continue
fi
echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR"
set -x
python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm}
set +x
done
done
done
print_banner "Reinstalling Apex with --cuda_ext and --cpp_ext"
pushd ../../..
python setup.py install --cuda_ext --cpp_ext
popd