Commit c2b62b7f authored by JR_ZZU

delete origin files

parent 2a4864d5
ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu
FROM ${FROM_IMAGE}
RUN \
git clone --recursive https://github.com/ROCmSoftwarePlatform/apex.git && \
cd apex && \
python3.6 setup.py install --cpp_ext --cuda_ext
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# APEX
## Introduction
[Introduction](README_ORIGIN.md)
## Installation
### System Requirements
- Linux.
- Python 3.7, 3.8, 3.9
- (**Recommended**) Upgrade pip
```
python3 -m pip install --upgrade pip #--user
```
### Install with pip (using dtk-23.04 as an example)
The latest apex release wheel can be obtained from the AI ecosystem packages in the [Developer Community](https://developer.hpccube.com/tool/#sdk) (the wheel must match your DCU Toolkit and Python versions)
```bash
python3 -m pip install apex-0.1+git2d8b360.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
### Install from source
#### Build environment setup (using dtk-23.04 as an example)
- Clone the apex source
```
git clone -b dtk-23.04 http://developer.hpccube.com/codes/aicomponent/apex.git
```
- Download DTK-23.04 from the DCU Toolkit section of the [Developer Community](https://developer.hpccube.com/tool/#sdk), extract it to /opt/, and create a symlink
```
cd /opt && ln -s dtk-23.04 dtk
```
- Obtain the matching pytorch release wheel from the AI ecosystem packages in the [Developer Community](https://developer.hpccube.com/tool/#sdk) (the wheel must match your DCU Toolkit and Python versions)
```bash
python3 -m pip install torch-1.13.1a0+git4c8a1fe.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
- Set the required environment variables and install the build dependencies
```bash
source /opt/dtk/env.sh
export PYTORCH_ROCM_ARCH="gfx906;gfx926"
export MAX_JOBS=16
pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
pip3 install wheel -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
```
#### Build and install
- Run the build commands
```shell
cd apex
CXX=hipcc CC=hipcc python3 setup.py --cpp_ext --cuda_ext bdist_wheel
pip install dist/apex*
```
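After the wheel is installed, a quick sanity check (a minimal sketch, not part of the official steps) is to import the package:
```python
# Hypothetical post-install check: these imports should succeed when the extensions built correctly.
import torch
import apex
from apex import amp

print(torch.__version__)
print(apex.__version__)  # available when the wheel ships a version module
```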
# Introduction
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides
# Contents
## 1. Amp: Automatic Mixed Precision
`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
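For reference, the typical integration looks like the following sketch (`model`, `optimizer`, and `loss` are placeholders from your existing script):
```python
from apex import amp

# 1) Let amp patch the model and optimizer after they are constructed.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# 2) Scale the loss for the backward pass so small FP16 gradients are not flushed to zero.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# 3) Step the optimizer as usual; gradients are unscaled before the step runs.
optimizer.step()
```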
[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
[API Documentation](https://nvidia.github.io/apex/amp.html)
[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
## 2. Distributed Training
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
[API Documentation](https://nvidia.github.io/apex/parallel.html)
[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)
[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)
The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
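A minimal launch sketch, assuming one process per GPU started by a launcher that sets the usual environment variables (`model` and `args.local_rank` are placeholders):
```python
import torch
from apex.parallel import DistributedDataParallel as DDP

# One process per GPU; the launcher provides rank/world size via the environment.
torch.distributed.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(args.local_rank)

model = model.cuda()
# apex's wrapper takes no device_ids/output_device arguments; it assumes
# the process already owns a single device.
model = DDP(model)
```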
### Synchronized Batch Normalization
`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
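A sketch of enabling it on an existing model before wrapping for distributed training (`model` is a placeholder; `convert_syncbn_model` walks the module tree):
```python
import apex

# Replace every torch.nn BatchNorm layer in the module tree with
# apex.parallel.SyncBatchNorm so running stats are allreduced across processes.
model = apex.parallel.convert_syncbn_model(model)
model = apex.parallel.DistributedDataParallel(model.cuda())
```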
### Checkpointing
To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
as well as `amp.load_state_dict()` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow:
```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
```
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
# Installation
## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions available at the moment.
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes
## From Source
To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
The latest stable release obtainable from https://pytorch.org should also work.
### Rocm
Apex on ROCm supports both a Python-only build and an extension build.
Note: PyTorch >= 1.5 is recommended for the extension build.
### To install using the Python-only build, run the following command in the apex folder:
```
python setup.py install
```
### To install with extensions enabled, run the following command in the apex folder:
```
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
python setup.py install --cpp_ext --cuda_ext
```
Note that installing Apex with the `--cuda_ext` flag also enables all the extensions supported on ROCm, including `--distributed_adam`, `--distributed_lamb`, `--bnp`, `--xentropy`, `--deprecated_fused_adam`, `--deprecated_fused_lamb`, and `--fast_multihead_attn`.
### Linux
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
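For example, with the C++/CUDA extensions built, the fused modules above become usable (a sketch with arbitrary shapes and hyperparameters):
```python
import torch
from apex.normalization import FusedLayerNorm
from apex.optimizers import FusedAdam

layer = FusedLayerNorm(1024).cuda()                  # uses the fused layer-norm kernel
optimizer = FusedAdam(layer.parameters(), lr=1e-3)   # uses the fused Adam kernel

out = layer(torch.randn(8, 1024, device="cuda"))
out.sum().backward()
optimizer.step()
```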
### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
Under construction...
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import math
def is_iterable(maybe_iterable):
return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)
def flatten_list(tens_list):
"""
flatten_list
"""
if not is_iterable(tens_list):
return tens_list
return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )
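# Example with hypothetical tensors: three [4, 8] tensors are concatenated along
# dim 0 and reshaped into a single [3, 4, 8] tensor, i.e. [seq_len][bsz][features]:
#   tens = [torch.randn(4, 8) for _ in range(3)]
#   flatten_list(tens).size()  # torch.Size([3, 4, 8])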
# These modules assume the input is NOT batch_first, i.e. shaped [seq_len][batch][features]
class bidirectionalRNN(nn.Module):
"""
bidirectionalRNN
"""
def __init__(self, inputRNN, num_layers=1, dropout = 0):
super(bidirectionalRNN, self).__init__()
self.dropout = dropout
self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)
self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)
self.rnns = nn.ModuleList([self.fwd, self.bckwrd])
#collect hidden option will return all hidden/cell states from entire RNN
def forward(self, input, collect_hidden=False):
"""
forward()
"""
seq_len = input.size(0)
bsz = input.size(1)
fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))
bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))
output = torch.cat( [fwd_out, bckwrd_out], -1 )
hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )
return output, hiddens
def reset_parameters(self):
"""
reset_parameters()
"""
for rnn in self.rnns:
rnn.reset_parameters()
def init_hidden(self, bsz):
"""
init_hidden()
"""
for rnn in self.rnns:
rnn.init_hidden(bsz)
def detach_hidden(self):
"""
detach_hidden()
"""
for rnn in self.rnns:
rnn.detach_hidden()
def reset_hidden(self, bsz):
"""
reset_hidden()
"""
for rnn in self.rnns:
rnn.reset_hidden(bsz)
def init_inference(self, bsz):
"""
init_inference()
"""
for rnn in self.rnns:
rnn.init_inference(bsz)
#assumes hidden_state[0] of inputRNN is output hidden state
#constructor either takes an RNNCell or list of RNN layers
class stackedRNN(nn.Module):
"""
stackedRNN
"""
def __init__(self, inputRNN, num_layers=1, dropout=0):
super(stackedRNN, self).__init__()
self.dropout = dropout
if isinstance(inputRNN, RNNCell):
self.rnns = [inputRNN]
for i in range(num_layers-1):
self.rnns.append(inputRNN.new_like(inputRNN.output_size))
elif isinstance(inputRNN, list):
assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
self.rnns=inputRNN
else:
raise RuntimeError()
self.nLayers = len(self.rnns)
self.rnns = nn.ModuleList(self.rnns)
'''
Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])
If collect hidden will also return Tuple(
[n_hidden_states][sequence steps] Tensor([layer][batch size][features])
)
If not collect hidden will also return Tuple(
[n_hidden_states] Tensor([layer][batch size][features])
'''
def forward(self, input, collect_hidden=False, reverse=False):
"""
forward()
"""
seq_len = input.size(0)
bsz = input.size(1)
inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)
hidden_states = [[] for i in range(self.nLayers)]
outputs = []
for seq in inp_iter:
for layer in range(self.nLayers):
if layer == 0:
prev_out = input[seq]
outs = self.rnns[layer](prev_out)
if collect_hidden:
hidden_states[layer].append(outs)
elif seq == seq_len-1:
hidden_states[layer].append(outs)
prev_out = outs[0]
outputs.append(prev_out)
if reverse:
outputs = list(reversed(outputs))
'''
At this point outputs is in format:
list( [seq_length] x Tensor([bsz][features]) )
need to convert it to:
list( Tensor([seq_length][bsz][features]) )
'''
output = flatten_list(outputs)
'''
hidden_states at this point is in format:
list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )
need to convert it to:
For not collect hidden:
list( [hidden_states] x Tensor([layer][bsz][features]) )
For collect hidden:
list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
'''
if not collect_hidden:
seq_len = 1
n_hid = self.rnns[0].n_hidden_states
new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]
for i in range(n_hid):
for j in range(seq_len):
for k in range(self.nLayers):
new_hidden[i][j][k] = hidden_states[k][j][i]
hidden_states = new_hidden
#Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )
#Reverse seq_length if reverse
if reverse:
hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)
#flatten layer dimension into tensor
hiddens = list( list(
flatten_list(seq) for seq in hidden )
for hidden in hidden_states )
#Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
#Remove seq_length dimension if not collect_hidden
if not collect_hidden:
hidden_states = list( entry[0] for entry in hidden_states)
return output, hidden_states
def reset_parameters(self):
"""
reset_parameters()
"""
for rnn in self.rnns:
rnn.reset_parameters()
def init_hidden(self, bsz):
"""
init_hidden()
"""
for rnn in self.rnns:
rnn.init_hidden(bsz)
def detach_hidden(self):
"""
detach_hidden()
"""
for rnn in self.rnns:
rnn.detach_hidden()
def reset_hidden(self, bsz):
"""
reset_hidden()
"""
for rnn in self.rnns:
rnn.reset_hidden(bsz)
def init_inference(self, bsz):
"""
init_inference()
"""
for rnn in self.rnns:
rnn.init_inference(bsz)
class RNNCell(nn.Module):
"""
RNNCell
gate_multiplier is related to the architecture you're working with
For LSTM-like it will be 4 and GRU-like will be 3.
Always assumes input is NOT batch_first.
Output size that's not hidden size will use output projection
Hidden_states is number of hidden states that are needed for cell
if one will go directly to cell as tensor, if more will go as list
"""
def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):
super(RNNCell, self).__init__()
self.gate_multiplier = gate_multiplier
self.input_size = input_size
self.hidden_size = hidden_size
self.cell = cell
self.bias = bias
self.output_size = output_size
if output_size is None:
self.output_size = hidden_size
self.gate_size = gate_multiplier * self.hidden_size
self.n_hidden_states = n_hidden_states
self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size))
self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size))
#Check if there's recurrent projection
if(self.output_size != self.hidden_size):
self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size))
self.b_ih = self.b_hh = None
if self.bias:
self.b_ih = nn.Parameter(torch.empty(self.gate_size))
self.b_hh = nn.Parameter(torch.empty(self.gate_size))
#hidden states for forward
self.hidden = [ None for states in range(self.n_hidden_states)]
self.reset_parameters()
def new_like(self, new_input_size=None):
"""
new_like()
"""
if new_input_size is None:
new_input_size = self.input_size
return type(self)(self.gate_multiplier,
new_input_size,
self.hidden_size,
self.cell,
self.n_hidden_states,
self.bias,
self.output_size)
#Use xavier where we can (weights), otherwise use uniform (bias)
def reset_parameters(self, gain=1):
"""
reset_parameters()
"""
stdev = 1.0 / math.sqrt(self.hidden_size)
for param in self.parameters():
param.data.uniform_(-stdev, stdev)
'''
Xavier reset:
def reset_parameters(self, gain=1):
stdv = 1.0 / math.sqrt(self.gate_size)
for param in self.parameters():
if (param.dim() > 1):
torch.nn.init.xavier_normal(param, gain)
else:
param.data.uniform_(-stdv, stdv)
'''
def init_hidden(self, bsz):
"""
init_hidden()
"""
for param in self.parameters():
if param is not None:
a_param = param
break
for i, _ in enumerate(self.hidden):
if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):
if i==0:
hidden_size = self.output_size
else:
hidden_size = self.hidden_size
tens = a_param.data.new(bsz, hidden_size).zero_()
self.hidden[i] = Variable(tens, requires_grad=False)
def reset_hidden(self, bsz):
"""
reset_hidden()
"""
for i, _ in enumerate(self.hidden):
self.hidden[i] = None
self.init_hidden(bsz)
def detach_hidden(self):
"""
detach_hidden()
"""
for i, _ in enumerate(self.hidden):
if self.hidden[i] is None:
raise RuntimeError("Must initialize hidden state before you can detach it")
for i, _ in enumerate(self.hidden):
self.hidden[i] = self.hidden[i].detach()
def forward(self, input):
"""
forward()
if not inited or bsz has changed this will create hidden states
"""
self.init_hidden(input.size()[0])
hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)
if(self.n_hidden_states > 1):
self.hidden = list(self.hidden)
else:
self.hidden=[self.hidden]
if self.output_size != self.hidden_size:
self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
return tuple(self.hidden)
from .models import LSTM, GRU, ReLU, Tanh, mLSTM
__all__ = ['models']
import torch
import torch.nn as nn
import torch.nn.functional as F
from .RNNBackend import RNNCell
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
import math
class mLSTMRNNCell(RNNCell):
"""
mLSTMRNNCell
"""
def __init__(self, input_size, hidden_size, bias = False, output_size = None):
gate_multiplier = 4
super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)
self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size))
self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size))
self.reset_parameters()
def forward(self, input):
"""
mLSTMRNNCell.forward()
"""
#if not inited or bsz has changed this will create hidden states
self.init_hidden(input.size()[0])
hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
self.hidden = list(
self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,
b_ih=self.b_ih, b_hh=self.b_hh)
)
if self.output_size != self.hidden_size:
self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
return tuple(self.hidden)
def new_like(self, new_input_size=None):
if new_input_size is None:
new_input_size = self.input_size
return type(self)(
new_input_size,
self.hidden_size,
self.bias,
self.output_size)
def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):
"""
mLSTMCell
"""
if input.is_cuda:
igates = F.linear(input, w_ih)
m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
hgates = F.linear(m, w_hh)
state = fusedBackend.LSTMFused.apply
return state(igates, hgates, hidden[1], b_ih, b_hh)
hx, cx = hidden
m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)
return hy, cy
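# For reference, the recurrence implemented above is the multiplicative LSTM
# (Krause et al., "Multiplicative LSTM for sequence modelling"):
#   m_t   = (w_mih @ x_t) * (w_mhh @ h_{t-1})          # elementwise product
#   gates = w_ih @ x_t + b_ih + w_hh @ m_t + b_hh      # split into i, f, g, o
#   c_t   = sigmoid(f) * c_{t-1} + sigmoid(i) * tanh(g)
#   h_t   = sigmoid(o) * tanh(c_t)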
import torch
from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell
from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
from .cells import mLSTMRNNCell, mLSTMCell
def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
"""
:class:`toRNNBackend`
"""
if bidirectional:
return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
else:
return stackedRNN(inputRNN, num_layers, dropout = dropout)
def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`LSTM`
"""
inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`GRU`
"""
inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`ReLU`
"""
inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`Tanh`
"""
inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`mLSTM`
"""
inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
import logging
import warnings
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
if torch.distributed.is_available():
from . import parallel
from . import amp
from . import fp16_utils
# For optimizers and normalization there is no Python fallback.
# Absence of cuda backend is a hard error.
# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext
# and therefore expects those backends to be available, but for some reason they actually
# aren't available (for example because they were built improperly in a way that isn't
# revealed until load time), the error message is timely and visible.
from . import optimizers
from . import normalization
from . import transformer
# Logging utilities for apex.transformer module
class RankInfoFormatter(logging.Formatter):
def format(self, record):
from apex.transformer.parallel_state import get_rank_info
record.rank_info = get_rank_info()
return super().format(record)
_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
cudnn_available = torch.backends.cudnn.is_available()
cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
warnings.warn(
f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
)
return False
return True
try:
from .version import version, git_hash, git_branch, dtk, abi, torch_version, dcu_version # noqa: F401
__version__, __dcu_version__ = version, dcu_version
except ImportError:
pass
from typing import Optional, Sequence
import torch
def _get_autocast_dtypes() -> Sequence[torch.dtype]:
if torch.cuda.is_bf16_supported():
return [torch.half, torch.bfloat16]
return [torch.half]
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
if not torch.is_autocast_enabled():
return torch.float or dtype
else:
return torch.get_autocast_gpu_dtype()
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
else:
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
# amp: Automatic Mixed Precision
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than the two steps
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp already can see those.
However, any custom C++ or CUDA code is outside of amp's (default)
view of things. For example, suppose I implemented a new recurrent
cell called a "forgetful recurrent unit" that calls directly into a
CUDA backend:
```python
from backend import FRUBackend
def fru(input, hidden, weight, bias):
    # call to CUDA code
    FRUBackend(input, hidden, weight, bias)
```
In this case, it is possible to get a runtime type mismatch. For
example, you might have `input` in fp16, and `weight` in fp32, and amp
doesn't have the visibility to insert an appropriate cast.
amp exposes two ways to handle "invisible" backend code: function
annotations and explicit registration.
#### Function annotation
The first way to handle backend code is a set of function annotations:
- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`
These correspond to:
- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type
In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:
```python
@amp.half_function
def fru(input, hidden, weight, bias):
    #...
```
#### Explicit registration
The other way to handle backend code is with explicit function
registration:
- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.initialize()`.
For our FRU unit, we can register the backend function directly:
```python
import backend
amp.register_half_function(backend, 'FRUBackend')
```
from .amp import init, half_function, bfloat16_function, float_function, promote_function,\
register_half_function, register_bfloat16_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize, state_dict, load_state_dict
from ._amp_state import master_params, _amp_state
VERSION = (0, 1, 0)
__version__ = '.'.join(map(str, VERSION))
# This is a "header object" that allows different amp modules to communicate.
# I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
# But apparently it's ok:
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
import torch
class AmpState(object):
def __init__(self):
self.hard_override=False
self.allow_incoming_model_not_fp32 = False
self.verbosity=1
# Attribute stash. Could also just stash things as global module attributes.
_amp_state = AmpState()
def warn_or_err(msg):
if _amp_state.hard_override:
print("Warning: " + msg)
else:
raise RuntimeError(msg)
# I'm not sure if allowing hard_override is a good idea.
# + " If you're sure you know what you're doing, supply " +
# "hard_override=True to amp.initialize.")
def maybe_print(msg, rank0=False):
distributed = torch.distributed.is_available() and \
torch.distributed.is_initialized() and \
torch.distributed.get_world_size() > 1
if _amp_state.verbosity > 0:
if rank0:
if distributed:
if torch.distributed.get_rank() == 0:
print(msg)
else:
print(msg)
else:
print(msg)
# def iter_params(param_groups):
# for group in param_groups:
# for p in group['params']:
# yield p
def master_params(optimizer):
"""
Generator expression that iterates over the params owned by ``optimizer``.
Args:
optimizer: An optimizer previously returned from ``amp.initialize``.
"""
for group in optimizer.param_groups:
for p in group['params']:
yield p
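# Typical usage (sketch): operate on the FP32 master weights rather than the
# model's low-precision weights, e.g. gradient clipping:
#   torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)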
import collections.abc as container_abcs
from types import MethodType
import functools
import sys
import warnings
import numpy as np
import torch
from ._amp_state import _amp_state, warn_or_err
from .handle import disable_casts
from .scaler import LossScaler
from ._process_optimizer import _process_optimizer
from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
if torch.distributed.is_available():
from ..parallel import DistributedDataParallel as apex_DDP
from ..parallel.LARC import LARC
def to_type(dtype, t):
if isinstance(t, torch.Tensor):
if not t.is_cuda:
# This should not be a hard error, since it may be legitimate.
warnings.warn("An input tensor was not cuda.")
# GANs require this.
# if t.requires_grad:
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
# "its gradients will not be properly allreduced by DDP.")
if t.is_floating_point():
return t.to(dtype)
return t
else:
# Trust the user's custom batch type, that's all I can do here.
return t.to(dtype)
# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
def applier(value, fn):
if isinstance(value, torch.Tensor):
return fn(value)
elif isinstance(value, str):
return value
elif isinstance(value, np.ndarray):
return value
elif hasattr(value, "to"): # Allow handling of custom batch classes
return fn(value)
elif isinstance(value, container_abcs.Mapping):
return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(applier(v, fn) for v in value)
else:
# Do I want this to fire off even if someone chooses to pass something ordinary like
# an int or float? May be more annoying than it's worth.
# print("Warning: unrecognized type in applier. If your input data is a custom class, "
# "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
# "Amp will check for your custom to() and invoke it to cast the batch's "
# "floating-point Tensors to the appropriate type. "
# "Also, if your data is a custom class, it is your responsibility to ensure that "
# "any Tensors you want to be cuda are already cuda."
return value
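# Example with a hypothetical batch: applier recurses through mappings/iterables
# and applies `fn` only to tensors (and objects exposing .to()), leaving strings,
# numpy arrays, and plain scalars untouched:
#   batch = {"img": torch.randn(2, 3), "meta": ("id-1", 3)}
#   applier(batch, lambda t: t.half())   # only batch["img"] is cast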
def check_models(models):
for model in models:
parallel_type = None
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
parallel_type = "torch.nn.parallel.DistributedDataParallel"
if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):
parallel_type = "apex.parallel.DistributedDataParallel"
if isinstance(model, torch.nn.parallel.DataParallel):
parallel_type = "torch.nn.parallel.DataParallel"
if parallel_type is not None:
raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
"Parallel wrappers should only be applied to the model(s) AFTER \n"
"the model(s) have been returned from amp.initialize.")
def check_params_fp32(models):
for model in models:
for name, param in model.named_parameters():
if param.is_floating_point():
if 'Half' in param.type() or 'BFloat16' in param.type():
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you do not need to call .half() or .bfloat16()\n"
"on your model before passing it, no matter what optimization level you choose.".format(
name, param.type()))
elif not param.is_cuda:
warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you need to provide a model with parameters\n"
"located on a CUDA device before passing it no matter what optimization level\n"
"you chose. Use model.to('cuda') to use the default device.".format(
name, param.type()))
# Backward compatibility for PyTorch 0.4
if hasattr(model, 'named_buffers'):
buf_iter = model.named_buffers()
else:
buf_iter = model._buffers
for obj in buf_iter:
if type(obj)==tuple:
name, buf = obj
else:
name, buf = obj, buf_iter[obj]
if buf.is_floating_point():
if 'Half' in buf.type():
warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
name, buf.type()))
elif not buf.is_cuda:
warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.initialize, you need to provide a model with buffers\n"
"located on a CUDA device before passing it no matter what optimization level\n"
"you chose. Use model.to('cuda') to use the default device.".format(
name, buf.type()))
def check_optimizers(optimizers):
for optim in optimizers:
bad_optim_type = None
if isinstance(optim, FP16_Optimizer_general):
bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
if isinstance(optim, FP16_Optimizer_for_fused):
bad_optim_type = "apex.optimizers.FP16_Optimizer"
if bad_optim_type is not None:
raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
"The optimizer(s) passed to amp.initialize() must be bare \n"
"instances of either ordinary Pytorch optimizers, or Apex fused \n"
"optimizers.\n")
class O2StateDictHook(object):
def __init__(self, fn):
self.fn = fn
def __call__(self, module, state_dict, prefix, local_metadata):
for key in state_dict:
param = state_dict[key]
if 'Half' in param.type() or 'BFloat16' in param.type():
param = param.to(torch.float32)
state_dict[key] = param
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
from .amp import init as amp_init
optimizers_was_list = False
if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
optimizers = [optimizers]
elif optimizers is None:
optimizers = []
elif isinstance(optimizers, list):
optimizers_was_list = True
check_optimizers(optimizers)
else:
check_optimizers([optimizers])
raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
if isinstance(models, torch.nn.Module):
models_was_list = False
models = [models]
elif isinstance(models, list):
models_was_list = True
else:
raise TypeError("models must be either a single model or a list of models.")
check_models(models)
if not _amp_state.allow_incoming_model_not_fp32:
check_params_fp32(models)
# In the future, when FP16_Optimizer can be deprecated and master weights can
# become an attribute, remember to stash master weights before casting the model.
if properties.cast_model_type:
if properties.keep_batchnorm_fp32:
for model in models:
convert_network(model, properties.cast_model_type)
else:
for model in models:
model.to(properties.cast_model_type)
input_caster = functools.partial(to_type, properties.cast_model_type)
if cast_model_outputs is not None:
output_caster = functools.partial(to_type, cast_model_outputs)
else:
output_caster = functools.partial(to_type, torch.float32)
for model in models:
# Patch the forward method to cast incoming data to the correct type, and
# outgoing data to float32, so "the user never needs to call .half()/.bfloat16()."
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
output = old_fwd(*applier(args, input_caster),
**applier(kwargs, input_caster))
return applier(output, output_caster)
return new_fwd
model.forward = patch_forward(model.forward)
# State dict trick to recast any preexisting per-param state tensors
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
# patch model.state_dict() to return float32 params
for model in models:
for module in model.modules():
module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))
elif cast_model_outputs is not None:
output_caster = functools.partial(to_type, cast_model_outputs)
for model in models:
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
output = old_fwd(*args, **kwargs)
return applier(output, output_caster)
return new_fwd
model.forward = patch_forward(model.forward)
for i, optimizer in enumerate(optimizers):
optimizers[i] = _process_optimizer(optimizer, properties)
_amp_state.loss_scalers = []
for _ in range(num_losses):
_amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
min_loss_scale=_amp_state.min_loss_scale,
max_loss_scale=_amp_state.max_loss_scale))
if properties.patch_torch_functions:
# handle is unused here. It's accessible later through a global value anyway.
handle = amp_init(loss_scale=properties.loss_scale,
patch_type=properties.patch_torch_functions_type,
verbose=(_amp_state.verbosity == 2))
for optimizer in optimizers:
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
def patch_step(old_step):
def new_step(self, *args, **kwargs):
with disable_casts():
output = old_step(*args, **kwargs)
return output
return new_step
optimizer.step = MethodType(patch_step(optimizer.step), optimizer)
if optimizers_was_list:
if models_was_list:
return models, optimizers
else:
return models[0], optimizers
else:
if models_was_list:
if len(optimizers) == 0:
return models
else:
return models, optimizers[0]
else:
if len(optimizers) == 0:
return models[0]
else:
return models[0], optimizers[0]
import types
from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print, _amp_state
import torch
from ..optimizers import FusedSGD
class AmpOptimizerState(object):
def __init__(self):
pass
def _master_params_to_model_params(self):
stash = self._amp_stash
if multi_tensor_applier.available:
if len(stash.all_fp16_params) > 0:
multi_tensor_applier(
stash.multi_tensor_scale,
stash.dummy_overflow_buf,
[stash.all_fp32_from_fp16_params, stash.all_fp16_params],
1.0)
else:
for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
def lazy_init_with_master_weights(self):
stash = self._amp_stash
stash.fp16_groups = []
stash.fp32_from_fp16_groups = []
stash.fp32_from_fp32_groups = []
for i, param_group in enumerate(self.param_groups):
# maybe_print("FP16_Optimizer processing param group {}:".format(i))
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
# maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
# .format(param.size()))
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
# We still need to recast per-param state tensors, if any, to FP32.
if param in self.state:
self.state[master_param] = self.state.pop(param)
elif param.type() == 'torch.cuda.FloatTensor':
# maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
# .format(param.size()))
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
stash.fp32_from_fp32_groups.append(fp32_params_this_group)
stash.all_fp16_params = []
for group in stash.fp16_groups:
stash.all_fp16_params += group
stash.all_fp32_from_fp16_params = []
for group in stash.fp32_from_fp16_groups:
stash.all_fp32_from_fp16_params += group
stash.all_fp32_from_fp32_params = []
for group in stash.fp32_from_fp32_groups:
stash.all_fp32_from_fp32_params += group
# all_fp16_grad_stash is only needed for fused optimizers.
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
# stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
for param in stash.all_fp32_from_fp16_params:
param.grad = None
for param in stash.all_fp32_from_fp32_params:
param.grad = None
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
self.load_state_dict(self.state_dict())
def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
# not much to do if scale == 1.0 and static scaling
if scaler.loss_scale() == 1.0 and not scaler.dynamic:
# Clear the stash.
for i in range(len(stashed_grads)):
stashed_grads[i] = None
return
if scale_override is not None:
grads_have_scale, stashed_have_scale, out_scale = scale_override
# This is a lot of python overhead...
grads_needing_unscale = []
grads_needing_unscale_with_stash = []
stashed = []
for param, stashed_grad in zip(params, stashed_grads):
if param.grad is None and stashed_grad is not None:
param.grad = stashed_grad
elif param.grad is not None and stashed_grad is None:
grads_needing_unscale.append(param.grad)
elif param.grad is not None and stashed_grad is not None:
grads_needing_unscale_with_stash.append(param.grad)
stashed.append(stashed_grad)
else: # param.grad is None and stashed_grad is None
continue
# unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
if len(grads_needing_unscale) > 0:
scaler.unscale(
grads_needing_unscale,
grads_needing_unscale,
None, # unused_scale, currently present to avoid API breakage elsewhere
models_are_masters=True,
scale_override=grads_have_scale/out_scale)
if len(grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed(
grads_needing_unscale_with_stash,
stashed,
grads_needing_unscale_with_stash,
scale_override=(grads_have_scale, stashed_have_scale, out_scale))
# Clear the stash.
for i in range(len(stashed_grads)):
stashed_grads[i] = None
def prepare_backward_with_master_weights(self):
stash = self._amp_stash
self._amp_lazy_init()
for i, param in enumerate(stash.all_fp16_params):
# Set up to leverage grad copy elision.
# This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
param.grad = None
# for i, param in enumerate(stash.all_fp32_from_fp16_params):
# stash.all_fp32_from_fp16_grad_stash[i] = param.grad
for i, param in enumerate(stash.all_fp32_from_fp32_params):
stash.all_fp32_from_fp32_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
def post_backward_with_master_weights(self, scaler):
stash = self._amp_stash
self._amp_lazy_init()
# This is a lot of python overhead...
fp16_grads_needing_unscale = []
new_fp32_grads = []
fp16_grads_needing_unscale_with_stash = []
preexisting_fp32_grads = []
for fp16_param, fp32_param in zip(stash.all_fp16_params,
stash.all_fp32_from_fp16_params):
if fp16_param.grad is None and fp32_param.grad is not None:
continue
elif fp16_param.grad is not None and fp32_param.grad is None:
fp32_param.grad = torch.empty_like(fp32_param)
fp16_grads_needing_unscale.append(fp16_param.grad)
new_fp32_grads.append(fp32_param.grad)
elif fp16_param.grad is not None and fp32_param.grad is not None:
fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
preexisting_fp32_grads.append(fp32_param.grad)
else: # fp16_param.grad is None and fp32_param.grad is None:
continue
if len(fp16_grads_needing_unscale) > 0:
scaler.unscale(
fp16_grads_needing_unscale,
new_fp32_grads,
scaler.loss_scale(),
models_are_masters=False)
if len(fp16_grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed(
fp16_grads_needing_unscale_with_stash,
preexisting_fp32_grads,
preexisting_fp32_grads)
# fp32 params can be treated as they would be in the "no_master_weights" case.
post_backward_models_are_masters(
scaler,
stash.all_fp32_from_fp32_params,
stash.all_fp32_from_fp32_grad_stash)
def lazy_init_no_master_weights(self):
stash = self._amp_stash
stash.all_fp16_params = []
stash.all_fp32_params = []
for i, param_group in enumerate(self.param_groups):
for i, param in enumerate(param_group['params']):
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
stash.all_fp16_params.append(param)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
else:
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
def prepare_backward_no_master_weights(self):
stash = self._amp_stash
self._amp_lazy_init()
for i, param in enumerate(stash.all_fp16_params):
stash.all_fp16_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
for i, param in enumerate(stash.all_fp32_params):
stash.all_fp32_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
def post_backward_no_master_weights(self, scaler):
stash = self._amp_stash
self._amp_lazy_init()
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
(stash.all_fp32_params, stash.all_fp32_grad_stash))
for params, stashed_grads in split_types:
post_backward_models_are_masters(scaler, params, stashed_grads)
#####################################################################################
# FusedSGD versions
#####################################################################################
# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
# outside the kernel, so we must accumulate directly into the model grads.
def prepare_backward_with_master_weights_FusedSGD(self):
if self.materialize_master_grads:
prepare_backward_with_master_weights(self)
else:
stash = self._amp_stash
self._amp_lazy_init()
for i, param in enumerate(stash.all_fp16_params):
stash.all_fp16_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
for i, param in enumerate(stash.all_fp32_from_fp32_params):
stash.all_fp32_from_fp32_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
def post_backward_with_master_weights_FusedSGD(self, scaler):
if self.materialize_master_grads:
post_backward_with_master_weights(self, scaler)
else:
stash = self._amp_stash
self._amp_lazy_init()
grads_have_scale = scaler.loss_scale()
stashed_have_scale = self.most_recent_scale
out_scale = grads_have_scale
if self.scale_set_by_backward:
out_scale = min(grads_have_scale, self.most_recent_scale)
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
(stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
# unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
# stashed_grads are scaled by self.most_recent_scale.
for params, stashed_grads in split_types:
post_backward_models_are_masters(scaler, params, stashed_grads,
(grads_have_scale, stashed_have_scale, out_scale))
self.most_recent_scale = out_scale
self.scale_set_by_backward = True
def prepare_backward_no_master_weights_FusedSGD(self):
prepare_backward_no_master_weights(self)
def post_backward_no_master_weights_FusedSGD(self, scaler):
post_backward_no_master_weights(self, scaler)
def _amp_lazy_init(self):
stash = self._amp_stash
if not stash.lazy_init_called:
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
def _process_optimizer(optimizer, properties):
if hasattr(optimizer, "_amp_stash"):
raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
else:
optimizer._amp_stash = AmpOptimizerState()
optimizer._amp_stash.lazy_init_called = False
optimizer._amp_stash.already_patched = False
optimizer._amp_stash.params_have_scaled_gradients = False
for name in ("_lazy_init_maybe_master_weights",
"_master_params_to_model_params",
"_prepare_amp_backward",
"_post_amp_backward",
"_amp_lazy_init"):
if hasattr(optimizer, name):
raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
# TODO: Centralize exposure and import error checking for the C backend.
if multi_tensor_applier.available:
import amp_C
optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);
if properties.master_weights:
optimizer._lazy_init_maybe_master_weights = types.MethodType(
lazy_init_with_master_weights, optimizer)
optimizer._master_params_to_model_params = types.MethodType(
_master_params_to_model_params, optimizer)
old_step = optimizer.step
def new_step(self, closure=None):
if closure is not None:
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
retval = old_step()
if not isinstance(self, FusedSGD):
self._master_params_to_model_params()
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
for param in self._amp_stash.all_fp32_from_fp16_params:
param.grad = None
return retval
optimizer.step = types.MethodType(new_step, optimizer)
old_zero_grad = optimizer.zero_grad
def new_zero_grad(self):
stash = self._amp_stash
self._amp_lazy_init()
# Zero the model grads.
for param in stash.all_fp16_params:
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
for param in stash.all_fp32_from_fp32_params:
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
# Clear the master grads that are independent of model grads
for param in self._amp_stash.all_fp32_from_fp16_params:
param.grad = None
optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
if isinstance(optimizer, FusedSGD):
optimizer._prepare_amp_backward = types.MethodType(
prepare_backward_with_master_weights_FusedSGD, optimizer)
optimizer._post_amp_backward = types.MethodType(
post_backward_with_master_weights_FusedSGD, optimizer)
else:
optimizer._prepare_amp_backward = types.MethodType(
prepare_backward_with_master_weights, optimizer)
optimizer._post_amp_backward = types.MethodType(
post_backward_with_master_weights, optimizer)
else:
optimizer._lazy_init_maybe_master_weights = types.MethodType(
lazy_init_no_master_weights, optimizer)
if isinstance(optimizer, FusedSGD):
optimizer._prepare_amp_backward = types.MethodType(
prepare_backward_no_master_weights_FusedSGD, optimizer)
optimizer._post_amp_backward = types.MethodType(
post_backward_no_master_weights_FusedSGD, optimizer)
else:
optimizer._prepare_amp_backward = types.MethodType(
prepare_backward_no_master_weights, optimizer)
optimizer._post_amp_backward = types.MethodType(
post_backward_no_master_weights, optimizer)
optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer)
old_add_param_group = optimizer.add_param_group
def new_add_param_group(self, new_group):
stash = self._amp_stash
if not stash.lazy_init_called:
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
assert isinstance(new_group, dict), "param group must be a dict"
new_params = new_group['params']
if isinstance(new_params, torch.Tensor):
new_group['params'] = [new_params]
elif isinstance(new_params, set):
raise TypeError('optimizer parameters need to be organized in ordered collections, but '
'the ordering of tensors in sets will change between runs. Please use a list instead.')
else:
new_group['params'] = list(new_params)
if properties.master_weights:
# Mutate new_group in-place to use FP32 master params
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(new_group['params']):
if param.requires_grad:
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
new_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
new_group['params'][i] = param
else:
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
stash.fp16_groups.append(fp16_params_this_group)
stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
stash.fp32_from_fp32_groups.append(fp32_params_this_group)
stash.all_fp16_params += fp16_params_this_group
stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
stash.all_fp32_from_fp32_params += fp32_params_this_group
# stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]
# It should be ok to let params be added with existing .grad attributes.
# for param in fp16_params_this_group:
# param.grad = None
# for param in fp32_from_fp16_params_this_group:
# param.grad = None
# for param in stash.fp32_params_this_group:
# param.grad = None
else:
for param in new_group['params']:
if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}:
stash.all_fp16_params.append(param)
stash.all_fp16_grad_stash.append(None)
elif param.type() == 'torch.cuda.FloatTensor':
stash.all_fp32_params.append(param)
stash.all_fp32_grad_stash.append(None)
else:
raise TypeError("Optimizer's parameters must be one of "
"torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. "
"Received {}".format(param.type()))
old_add_param_group(new_group)
optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)
return optimizer
from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ._amp_state import _amp_state
from .frontend import *
import functools
import itertools
import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
def wrapper(*args, **kwargs):
handle = _DECORATOR_HANDLE
if handle is None or not handle.is_active():
return orig_fn(*args, **kwargs)
inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
handle.verbose)
return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
return wrapper
# Decorator form
def half_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_half, wrap_fn)
def bfloat16_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn)
def float_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
def promote_function(fn):
wrap_fn = functools.partial(wrap.make_promote_wrapper)
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
# Registry form
def register_half_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
def register_bfloat16_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16))
def register_float_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
def register_promote_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_PROMOTE_REGISTRY.add((module, name))
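# --- Editor's illustrative sketch (not part of the original source) ------------
# Registry form in practice: patch functions you do not own by (module, name).
# The registries are consumed and cleared inside ``init`` below, so these calls
# must happen before ``amp.init`` / ``amp.initialize``. ``my_extension`` and its
# ``fused_kernel`` attribute are hypothetical placeholders.
def _example_registry_form(my_extension):
    register_half_function(my_extension, "fused_kernel")        # force low-precision inputs
    register_float_function(torch.nn.functional, "layer_norm")  # force FP32 inputs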
# Top-level function to insert _all_ the hooks.
def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False):
global _DECORATOR_HANDLE
if not enabled:
handle = NoOpHandle()
_DECORATOR_HANDLE = handle
return handle
handle = AmpHandle(loss_scale, enable_caching, verbose)
# 0) Force-{fp16, fp32} for user-annotated functions
for mod, fn, cast_fn in _USER_CAST_REGISTRY:
try_caching = (cast_fn == utils.maybe_half)
wrap.cached_cast(mod, fn, cast_fn, handle,
try_caching, verbose)
_USER_CAST_REGISTRY.clear()
# 0.5) Force-promote for user-annotated functions
for mod, fn in _USER_PROMOTE_REGISTRY:
wrap.promote(mod, fn, handle, verbose)
_USER_PROMOTE_REGISTRY.clear()
    # Conditionally choose between the fp16 and bfloat16 function lists, depending on patch_type
if patch_type == torch.float16:
low_prec_funcs = 'FP16_FUNCS'
maybe_low_prec = utils.maybe_half
low_prec_tensor = torch.cuda.HalfTensor
elif patch_type == torch.bfloat16:
low_prec_funcs = 'BFLOAT16_FUNCS'
maybe_low_prec = utils.maybe_bfloat16
low_prec_tensor = torch.cuda.BFloat16Tensor
else:
raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize." +
"Supported types are: torch.float16 and torch.bfloat16.")
# 1) Force-{fp16, fp32} on white- / black-list functions
override_modules = [functional_overrides,
torch_overrides,
tensor_overrides]
cast_table = [(low_prec_funcs, maybe_low_prec),
('FP32_FUNCS', utils.maybe_float)]
for module, (list_name, cast_fn) in itertools.product(override_modules,
cast_table):
for fn in getattr(module, list_name):
try_caching = (cast_fn == maybe_low_prec)
wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
try_caching, verbose)
# 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
# methods on FloatTensor, since they're distinct types.
if compat.tensor_is_float_tensor():
for fn in tensor_overrides.FP16_FUNCS:
wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
handle, try_caching=True, verbose=verbose)
for fn in tensor_overrides.FP32_FUNCS:
wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
handle, try_caching=False, verbose=verbose)
# 2) Enable type-promotion on multi-arg functions and methods.
# NB: special handling for sequence fns (e.g. `torch.cat`).
promote_modules = [torch_overrides, tensor_overrides]
promote_table = [('CASTS', wrap.promote),
('SEQUENCE_CASTS', wrap.sequence_promote)]
for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
promote_table):
for fn in getattr(promote_mod, list_name):
promote_fn(promote_mod.MODULE, fn, handle, verbose)
    # 2.5) Pre-0.4, add the promote methods directly to HalfTensor and FloatTensor types
if compat.tensor_is_float_tensor():
for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
torch.cuda.HalfTensor],
promote_table):
for fn in getattr(tensor_overrides, list_name):
promote_fn(cls, fn, handle, verbose)
# 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16.
# NB: this is overly conservative.
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
# 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
# 4) For other in-place methods, match the type of self tensor
for fn in utils.as_inplace(itertools.chain(
getattr(tensor_overrides, low_prec_funcs),
tensor_overrides.CASTS)):
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
# 5) RNNs + RNN cells are whitelisted specially
if rnn_compat.has_old_rnns():
wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
if not rnn_compat.has_old_rnns():
# Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
# Wrap all the rnns
for x in rnn_compat.RNN_NAMES:
wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose)
# Wrap all the RNN cells
rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose)
# 6) Place error+print message on banned functions.
# Or, if allow_banned, then cast to FP32.
for fn, err_msg in functional_overrides.BANNED_FUNCS:
if allow_banned:
wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
handle, try_caching=True, verbose=verbose)
else:
wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
_DECORATOR_HANDLE = handle
_amp_state.handle = handle
return handle
import torch
# True for post-0.4, when Variables/Tensors merged.
def variable_is_tensor():
v = torch.autograd.Variable()
return isinstance(v, torch.Tensor)
def tensor_is_variable():
x = torch.Tensor()
return type(x) == torch.autograd.Variable
# False for post-0.4
def tensor_is_float_tensor():
x = torch.Tensor()
return type(x) == torch.FloatTensor
# Akin to `torch.is_tensor`, but returns True for Variable
# objects in pre-0.4.
def is_tensor_like(x):
return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)
# Wraps `torch.is_floating_point` if present, otherwise checks
# the suffix of `x.type()`.
def is_floating_point(x):
if hasattr(torch, 'is_floating_point'):
return torch.is_floating_point(x)
try:
torch_type = x.type()
return torch_type.endswith('FloatTensor') or \
torch_type.endswith('HalfTensor') or \
torch_type.endswith('DoubleTensor') or \
torch_type.endswith('BFloat16Tensor')
except AttributeError:
return False
def scalar_python_val(x):
if hasattr(x, 'item'):
return x.item()
else:
if isinstance(x, torch.autograd.Variable):
return x.data[0]
else:
return x[0]
# Accounts for the possibility that some ops may be removed from a namespace.
def filter_attrs(module, attrs):
return list(attrname for attrname in attrs if hasattr(module, attrname))
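# --- Editor's illustrative sketch (not part of the original source) ------------
# Quick demonstration of the compat helpers above on a post-0.4 PyTorch, where
# Variables and Tensors are the same type.
def _example_compat_helpers():
    x = torch.zeros(3)
    assert variable_is_tensor()           # post-0.4: Variable() is a Tensor
    assert not tensor_is_float_tensor()   # post-0.4: type(Tensor()) is torch.Tensor
    assert is_tensor_like(x) and is_floating_point(x)
    assert scalar_python_val(x.sum()) == 0.0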
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err, maybe_print
from collections import OrderedDict
class Properties(object):
"""
This class has two purposes: to establish a set of default properties,
and to route setting of these attributes through __setattr__ so that (in theory)
they can be checked for consistency with other existing args.
"""
def __init__(self):
self.options = {
"enabled" : False,
"opt_level" : None,
"cast_model_type" : None,
"patch_torch_functions" : False,
# TODO: patch_torch_functions_type could probably be unified with
# patch_torch_functions. Currently introducing a new attribute
# to be on the safer side and not break stuff.
"patch_torch_functions_type" : None,
"keep_batchnorm_fp32" : None,
"master_weights" : None,
"loss_scale" : 1.0,
# Reserved for future functionality
# "fused_optimizer" : False,
# "enable_ddp_interop" : False,
}
"""
This function allows updating several options at a time without routing through
__setattr__ checks, to avoid "you can't get there from here" scenarios.
Currently not intended to be exposed; users are expected to select an opt_level
and apply consistent modifications.
"""
def _update_options_dict(self, new_options):
for k, v in new_options:
if k in self.options:
self.options[k] = v
else:
raise ValueError("Tried to set unexpected option {}".format(k))
"""
The members of "options" are not direct attributes of self, so access attempts
will roll down to __getattr__. This borrows from the logic in torch.nn.Module.
"""
def __getattr__(self, name):
if "options" in self.__dict__:
options = self.__dict__["options"]
if name in options:
return options[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
def __setattr__(self, name, value):
if "options" in self.__dict__:
if name in self.options:
# print("setting {} {}".format(name, value))
if name == "cast_model_type":
if self.opt_level in {"O1", "O4"} and value is not None:
if value is not False:
if value is not torch.float32:
warn_or_err("O1 inserts casts around Torch functions rather than "
"model weights, so with O1, the model weights themselves "
"should remain FP32. If you wish to cast the model to a "
"different type, use opt_level='O2' or 'O3'. " +
"cast_model_type was {}".format(value))
self.options[name] = value
elif name == "patch_torch_functions":
if self.opt_level not in {"O1", "O4"} and value:
warn_or_err("Currently, patch_torch_functions=True should only be set by "
"selecting opt_level='O1' or 'O4'.")
self.options[name] = value
elif name == "patch_torch_functions_type":
if self.opt_level not in {"O1", "O4"} and value is not None:
warn_or_err("Currently, patch_torch_functions_type should only be set by "
"selecting opt_level='O1' or 'O4'.")
elif self.opt_level == "O1" and value != torch.float16:
warn_or_err("patch_torch_functions_type should only be set to torch.float16 "
"for opt_level='O1.")
elif self.opt_level == "O4" and value != torch.bfloat16:
warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 "
"for opt_level='O4.")
else:
self.options[name] = value
elif name == "keep_batchnorm_fp32":
if self.opt_level in {"O1", "O4"} and value is not None:
warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched "
"to run in FP32, so keep_batchnorm_fp32 should be None." +
" keep_batchnorm_fp32 was {}".format(value))
if value == "False":
self.options[name] = False
elif value == "True":
self.options[name] = True
else:
assert (value is True or value is False or value is None),\
"keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
"or None, found keep_batchnorm_fp32={}".format(value)
self.options[name] = value
elif name == "master_weights":
if self.opt_level in {"O1", "O4"} and value is not None:
warn_or_err("It doesn't make sense to use master_weights with O1 and O4 . "
"With O1 and O4, your model weights themselves should be FP32.")
self.options[name] = value
elif name == "loss_scale":
if value == "dynamic":
self.options[name] = value
else:
self.options[name] = float(value)
else:
self.options[name] = value
else:
super(Properties, self).__setattr__(name, value)
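# --- Editor's illustrative sketch (not part of the original source) ------------
# How the attribute routing above behaves: options live in the ``options`` dict,
# reads go through __getattr__, and writes go through the __setattr__ checks
# (including string-to-float / string-to-bool coercion).
def _example_properties_routing():
    props = Properties()
    props.loss_scale = "128.0"            # coerced to float by __setattr__
    assert props.loss_scale == 128.0      # read routed through __getattr__
    props.keep_batchnorm_fp32 = "True"    # the string form is normalised to bool
    assert props.keep_batchnorm_fp32 is True
    return props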
""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
class O3:
brief = "O3: Pure FP16 training."
more = "Calls .half() on your model, converting the entire model to FP16.\n"\
"A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
"so you don't need to change your data pipeline.\n"\
"This mode is useful for establishing a performance ceiling.\n"\
"It's also possible training may 'just work' in this mode.\n"\
"If not, try other optimization levels."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O3"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = 1.0
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O2:
brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
"to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
"your data pipeline.\n"\
"O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
"these master weights, then copy the master weights into the FP16 model weights.\n"\
"Master weights can also improve convergence and stability."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O2"
properties.cast_model_type = torch.float16
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = "dynamic"
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O1:
brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
more = "The type of your model's weights is not altered. However, internally,\n"\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
"while operations that might benefit from the additional stability of FP32 are patched\n"\
"to cast their inputs to fp32.\n"\
"O1 is the safest way to try mixed precision training, and is recommended when\n"\
"trying mixed precision training for the first time."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O1"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.patch_torch_functions_type = torch.float16
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = "dynamic"
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O0:
brief = "O0: Pure FP32 training.\n"
more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
"types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O0"
properties.cast_model_type = torch.float32
properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = None
properties.master_weights = False
properties.loss_scale = 1.0
# properties.fused_optimizer = False
# properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O4:
brief = "O4: Insert automatic casts around Pytorch functions and Tensor methods.\n"
more = "The type of your model's weights is not altered. However, internally,\n"\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\
"while operations that might benefit from the additional stability of FP32 are patched\n"\
"to cast their inputs to fp32.\n"\
"Loss scaling is not required in O4 mode since bflaot16 has the same dynamic range as fp32."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O4"
properties.cast_model_type = None
properties.patch_torch_functions = True
properties.patch_torch_functions_type = torch.bfloat16
properties.keep_batchnorm_fp32 = None
properties.master_weights = None
properties.loss_scale = 1
return properties # modified in place so this isn't really necessary
class O5:
brief = "O5: BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n"
more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\
"to BFLOAT16. Batchnorms are retained in FP32 for additional stability.\n"\
"The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\
"your data pipeline.\n"\
"O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\
"these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\
"Master weights can also improve convergence and stability."
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O5"
properties.cast_model_type = torch.bfloat16
        properties.patch_torch_functions = False
properties.patch_torch_functions_type = None
properties.keep_batchnorm_fp32 = True
properties.master_weights = True
properties.loss_scale = 1
return properties # modified in place so this isn't really necessary
opt_levels = {"O3": O3(),
"O2": O2(),
"O1": O1(),
"O0": O0(),
"O4": O4(),
"O5": O5()}
# allow user to directly pass Properties struct as well?
def initialize(
models,
optimizers=None,
enabled=True,
opt_level="O1",
cast_model_type=None,
patch_torch_functions=None,
patch_torch_functions_type=None,
keep_batchnorm_fp32=None,
master_weights=None,
loss_scale=None,
cast_model_outputs=None,
num_losses=1,
verbosity=1,
min_loss_scale=None,
max_loss_scale=2.**24
):
"""
Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
chosen ``opt_level`` and overridden properties, if any.
``amp.initialize`` should be called **after** you have finished
constructing your model(s) and
optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
See `Distributed training`_ in the Imagenet example.
Currently, ``amp.initialize`` should only be called **once**,
although it can process an arbitrary number of
models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
If you think your use case requires ``amp.initialize`` to be called more than once,
`let us know`_.
Any property keyword argument that is not ``None`` will be interpreted as a manual override.
To prevent having to rewrite anything else in your script, name the returned models/optimizers
to replace the passed models/optimizers, as in the code sample below.
Args:
models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast.
REQUIRED for training, optional for inference.
enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
should run as if Amp were not present.
opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
"O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above.
cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
above.
patch_torch_functions (bool, optional, default=None): Optional property override.
patch_torch_functions_type (``torch.dtype``, optional, default=None): Optional property override
keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
passed as a string, must be the string "True" or "False".
master_weights (bool, optional, default=None): Optional property override.
loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
must be a string representing a number, e.g., "128.0", or the string "dynamic".
cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
of your model(s) are always cast to a particular type regardless of ``opt_level``.
num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
passes you plan to use. When used in conjunction with the ``loss_id`` argument to
``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
which can improve stability. See "Multiple models/optimizers/losses"
under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still
support multiple losses/backward passes, but use a single global loss scale
for all of them.
verbosity (int, default=1): Set to 0 to suppress Amp-related output.
min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
loss scaling. The default value of None means that no floor is imposed.
If dynamic loss scaling is not used, `min_loss_scale` is ignored.
max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
Returns:
Model(s) and optimizer(s) modified according to the ``opt_level``.
If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
also be a list.
Permissible invocations::
model, optim = amp.initialize(model, optim,...)
model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
[model1, model2], optim = amp.initialize([model1, model2], optim,...)
[model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
# This is not an exhaustive list of the cross product of options that are possible,
# just a set of examples.
model, optim = amp.initialize(model, optim, opt_level="O0")
model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default
model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default
model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")
model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.
.. _`Distributed training`:
https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training
.. _`Imagenet example`:
https://github.com/NVIDIA/apex/tree/master/examples/imagenet
.. _`Advanced Amp Usage`:
https://nvidia.github.io/apex/advanced.html
.. _`Advanced Amp Usage topic`:
https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses
.. _`let us know`:
https://github.com/NVIDIA/apex/issues
"""
_amp_state.opt_properties = Properties()
_amp_state.verbosity = verbosity
if not enabled:
if optimizers is None:
return models
else:
return models, optimizers
if not torch.backends.cudnn.enabled:
raise RuntimeError(
"Amp requires torch.backends.cudnn.enabled = True")
if opt_level not in opt_levels:
raise RuntimeError(
"Unexpected optimization level {}. ".format(opt_level) +
"Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
"not the number zero.")
else:
_amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
maybe_print("Defaults for this optimization level are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:26} : {}".format(k, v), True)
_amp_state.min_loss_scale = min_loss_scale
_amp_state.max_loss_scale = max_loss_scale
maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
# I chose to have the keyword arguments listed directly in the argument list,
# instead of **kwargs, so I can't use kwargs.items() here.
if enabled is not None:
_amp_state.opt_properties.enabled = enabled
if opt_level is not None:
_amp_state.opt_properties.opt_level = opt_level
if cast_model_type is not None:
_amp_state.opt_properties.cast_model_type = cast_model_type
if patch_torch_functions is not None:
_amp_state.opt_properties.patch_torch_functions = patch_torch_functions
if patch_torch_functions_type is not None:
_amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type
if keep_batchnorm_fp32 is not None:
_amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
if master_weights is not None:
_amp_state.opt_properties.master_weights = master_weights
if loss_scale is not None:
_amp_state.opt_properties.loss_scale = loss_scale
maybe_print("After processing overrides, optimization options are:", True)
for k, v in _amp_state.opt_properties.options.items():
maybe_print("{:26} : {}".format(k, v), True)
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
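# --- Editor's illustrative sketch (not part of the original source) ------------
# Typical training-loop usage of ``initialize`` as documented above, via the
# public ``apex.amp`` namespace. The model, optimizer, loader and loss_fn are
# hypothetical; ``amp.scale_loss`` is the loss-scaling context manager exported
# by this package.
def _example_training_loop(model, optimizer, loader, loss_fn):
    from apex import amp
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
    for data, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
    return model, optimizer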
def state_dict(destination=None):
if destination is None:
destination = OrderedDict()
for idx, loss_scaler in enumerate(_amp_state.loss_scalers):
destination['loss_scaler%d' % idx] = {
'loss_scale': loss_scaler.loss_scale(),
'unskipped': loss_scaler._unskipped,
}
return destination
def load_state_dict(state_dict):
    # Check if state_dict contains the same number of loss_scalers as the current setup
if len(state_dict) != len(_amp_state.loss_scalers):
print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(
len(state_dict), len(_amp_state.loss_scalers)))
state_dict = state_dict.copy()
nb_loss_scalers = len(_amp_state.loss_scalers)
unexpected_keys = []
    # Track idx manually rather than using enumerate, so unexpected keys do not advance the loss_scaler index
idx = 0
for key in state_dict:
if 'loss_scaler' not in key:
unexpected_keys.append(key)
else:
if idx > (nb_loss_scalers - 1):
print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(
idx, nb_loss_scalers))
break
_amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']
_amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']
idx += 1
if len(unexpected_keys) > 0:
raise RuntimeError(
'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
# TODO: is this necessary/useful?
# def check_option_consistency(enabled=True,
# opt_level=None,
# cast_model_type=None,
# patch_torch_functions=None,
# keep_batchnorm_fp32=None,
# master_weights=None,
# loss_scale=None,
# enable_ddp_interop=None,
# hard_override=False):
# """
# Utility function that enables users to quickly check if the option combination they intend
# to use is permitted. ``check_option_consistency`` does not require models or optimizers
# to be constructed, and can be called at any point in the script. ``check_option_consistency``
# is totally self-contained; it does not set any amp global state or affect anything outside
# of itself.
# """
#
# if not enabled:
# return
#
# if opt_level not in opt_levels:
# raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
# else:
# opt_properties = opt_levels[opt_level](Properties())
# print("Selected optimization level {}", opt_levels[opt_level].brief)
# print("Defaults for this optimization level are:")
# for k, v in opt_properties.options:
# print("{:22} : {}".format(k, v))
#
# print("Processing user overrides (additional kwargs that are not None)...")
# for k, v in kwargs:
# if k not in _amp_state.opt_properties.options:
# raise RuntimeError("Unexpected kwarg {}".format(k))
# if v is not None:
# setattr(opt_properties, k, v)
#
# print("After processing overrides, optimization options are:")
# for k, v in opt_properties.options:
# print("{:22} : {}".format(k, v))