Commit 2fa4dbaf authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Initial release

parents
apex.egg-info
dist
build
docs/build
\ No newline at end of file
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
# Introduction
This is a repo is designed to hold PyTorch modules and utilities that are under active development and experimental. This repo is not designed as a long term solution or a production solution. Things placed in here are intended to be eventually moved to upstream PyTorch.
# Requirements
Python 3
PyTorch 0.3 or newer
CUDA 9
# [Full Documentation](https://nvidia.github.io/apex)
# Quick Start
To build the extension run the following command in the root directory of this project
```
python setup.py install
```
To use the extension simply run
```
import apex
```
and optionally (if required for your use)
```
import apex._C as apex_backend
```
# What's included
Current version of apex contains:
1. Mixed precision utilities can be found [here](https://nvidia.github.io/apex/fp16_utils) examples of using mixed precision utilities can be found for the [PyTorch imagenet example](https://github.com/csarofeen/examples/tree/apex/imagenet) and the [PyTorch word language model example](https://github.com/csarofeen/examples/tree/apex/word_language_model).
2. Parallel utilities can be found [here](https://nvidia.github.io/apex/parallel) and an example/walkthrough can be found [here](https://github.com/csarofeen/examples/tree/apex/distributed)
- apex/parallel/distributed.py contains a simplified implementation of PyTorch's DistributedDataParallel that's optimized for use with NCCL in single gpu / process mode
- apex/parallel/multiproc.py is a simple multi-process launcher that can be used on a single node/computer with multiple GPU's
3. Reparameterization function that allows you to recursively apply reparameterization to an entire module (including children modules).
4. An experimental and in development flexible RNN API.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import math
def is_iterable(maybe_iterable):
return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)
def flatten_list(tens_list):
"""
flatten_list
"""
if not is_iterable(tens_list):
return tens_list
return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )
#These modules always assumes batch_first
class bidirectionalRNN(nn.Module):
"""
bidirectionalRNN
"""
def __init__(self, inputRNN, num_layers=1, dropout = 0):
super(bidirectionalRNN, self).__init__()
self.dropout = dropout
self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)
self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)
self.rnns = nn.ModuleList([self.fwd, self.bckwrd])
#collect hidden option will return all hidden/cell states from entire RNN
def forward(self, input, collect_hidden=False):
"""
forward()
"""
seq_len = input.size(0)
bsz = input.size(1)
fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))
bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))
output = torch.cat( [fwd_out, bckwrd_out], -1 )
hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )
return output, hiddens
def reset_parameters(self):
"""
reset_parameters()
"""
for rnn in self.rnns:
rnn.reset_parameters()
def init_hidden(self, bsz):
"""
init_hidden()
"""
for rnn in self.rnns:
rnn.init_hidden(bsz)
def detach_hidden(self):
"""
detach_hidden()
"""
for rnn in self.rnns:
rnn.detachHidden()
def reset_hidden(self, bsz):
"""
reset_hidden()
"""
for rnn in self.rnns:
rnn.reset_hidden(bsz)
def init_inference(self, bsz):
"""
init_inference()
"""
for rnn in self.rnns:
rnn.init_inference(bsz)
#assumes hidden_state[0] of inputRNN is output hidden state
#constructor either takes an RNNCell or list of RNN layers
class stackedRNN(nn.Module):
"""
stackedRNN
"""
def __init__(self, inputRNN, num_layers=1, dropout=0):
super(stackedRNN, self).__init__()
self.dropout = dropout
if isinstance(inputRNN, RNNCell):
self.rnns = [inputRNN]
for i in range(num_layers-1):
self.rnns.append(inputRNN.new_like(inputRNN.output_size))
elif isinstance(inputRNN, list):
assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
self.rnns=inputRNN
else:
raise RuntimeError()
self.nLayers = len(self.rnns)
self.rnns = nn.ModuleList(self.rnns)
'''
Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])
If collect hidden will also return Tuple(
[n_hidden_states][sequence steps] Tensor([layer][batch size][features])
)
If not collect hidden will also return Tuple(
[n_hidden_states] Tensor([layer][batch size][features])
'''
def forward(self, input, collect_hidden=False, reverse=False):
"""
forward()
"""
seq_len = input.size(0)
bsz = input.size(1)
inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)
hidden_states = [[] for i in range(self.nLayers)]
outputs = []
for seq in inp_iter:
for layer in range(self.nLayers):
if layer == 0:
prev_out = input[seq]
outs = self.rnns[layer](prev_out)
if collect_hidden:
hidden_states[layer].append(outs)
elif seq == seq_len-1:
hidden_states[layer].append(outs)
prev_out = outs[0]
outputs.append(prev_out)
if reverse:
outputs = list(reversed(outputs))
'''
At this point outputs is in format:
list( [seq_length] x Tensor([bsz][features]) )
need to convert it to:
list( Tensor([seq_length][bsz][features]) )
'''
output = flatten_list(outputs)
'''
hidden_states at this point is in format:
list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )
need to convert it to:
For not collect hidden:
list( [hidden_states] x Tensor([layer][bsz][features]) )
For collect hidden:
list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
'''
if not collect_hidden:
seq_len = 1
n_hid = self.rnns[0].n_hidden_states
new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]
for i in range(n_hid):
for j in range(seq_len):
for k in range(self.nLayers):
new_hidden[i][j][k] = hidden_states[k][j][i]
hidden_states = new_hidden
#Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )
#Reverse seq_length if reverse
if reverse:
hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)
#flatten layer dimension into tensor
hiddens = list( list(
flatten_list(seq) for seq in hidden )
for hidden in hidden_states )
#Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
#Remove seq_length dimension if not collect_hidden
if not collect_hidden:
hidden_states = list( entry[0] for entry in hidden_states)
return output, hidden_states
def reset_parameters(self):
"""
reset_parameters()
"""
for rnn in self.rnns:
rnn.reset_parameters()
def init_hidden(self, bsz):
"""
init_hidden()
"""
for rnn in self.rnns:
rnn.init_hidden(bsz)
def detach_hidden(self):
"""
detach_hidden()
"""
for rnn in self.rnns:
rnn.detach_hidden()
def reset_hidden(self, bsz):
"""
reset_hidden()
"""
for rnn in self.rnns:
rnn.reset_hidden(bsz)
def init_inference(self, bsz):
"""
init_inference()
"""
for rnn in self.rnns:
rnn.init_inference(bsz)
class RNNCell(nn.Module):
"""
RNNCell
gate_multiplier is related to the architecture you're working with
For LSTM-like it will be 4 and GRU-like will be 3.
Always assumes input is NOT batch_first.
Output size that's not hidden size will use output projection
Hidden_states is number of hidden states that are needed for cell
if one will go directly to cell as tensor, if more will go as list
"""
def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):
super(RNNCell, self).__init__()
self.gate_multiplier = gate_multiplier
self.input_size = input_size
self.hidden_size = hidden_size
self.cell = cell
self.bias = bias
self.output_size = output_size
if output_size is None:
self.output_size = hidden_size
self.gate_size = gate_multiplier * self.hidden_size
self.n_hidden_states = n_hidden_states
self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))
self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))
#Check if there's recurrent projection
if(self.output_size != self.hidden_size):
self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))
self.b_ih = self.b_hh = None
if self.bias:
self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))
self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))
#hidden states for forward
self.hidden = [ None for states in range(self.n_hidden_states)]
self.reset_parameters()
def new_like(self, new_input_size=None):
"""
new_like()
"""
if new_input_size is None:
new_input_size = self.input_size
return type(self)(self.gate_multiplier,
new_input_size,
self.hidden_size,
self.cell,
self.n_hidden_states,
self.bias,
self.output_size)
#Use xavier where we can (weights), otherwise use uniform (bias)
def reset_parameters(self, gain=1):
"""
reset_parameters()
"""
stdev = 1.0 / math.sqrt(self.hidden_size)
for param in self.parameters():
param.data.uniform_(-stdev, stdev)
'''
Xavier reset:
def reset_parameters(self, gain=1):
stdv = 1.0 / math.sqrt(self.gate_size)
for param in self.parameters():
if (param.dim() > 1):
torch.nn.init.xavier_normal(param, gain)
else:
param.data.uniform_(-stdv, stdv)
'''
def init_hidden(self, bsz):
"""
init_hidden()
"""
for param in self.parameters():
if param is not None:
a_param = param
break
for i, _ in enumerate(self.hidden):
if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):
if i==0:
hidden_size = self.output_size
else:
hidden_size = self.hidden_size
tens = a_param.data.new(bsz, hidden_size).zero_()
self.hidden[i] = Variable(tens, requires_grad=False)
def reset_hidden(self, bsz):
"""
reset_hidden()
"""
for i, _ in enumerate(self.hidden):
self.hidden[i] = None
self.init_hidden(bsz)
def detach_hidden(self):
"""
detach_hidden()
"""
for i, _ in enumerate(self.hidden):
if self.hidden[i] is None:
raise RuntimeError("Must inialize hidden state before you can detach it")
for i, _ in enumerate(self.hidden):
self.hidden[i] = self.hidden[i].detach()
def forward(self, input):
"""
forward()
if not inited or bsz has changed this will create hidden states
"""
self.init_hidden(input.size()[0])
hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)
if(self.n_hidden_states > 1):
self.hidden = list(self.hidden)
else:
self.hidden=[self.hidden]
if self.output_size != self.hidden_size:
self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
return tuple(self.hidden)
from .models import LSTM, GRU, ReLU, Tanh, mLSTM
__all__ = ['models']
import torch
import torch.nn as nn
import torch.nn.functional as F
from .RNNBackend import RNNCell
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
import math
class mLSTMRNNCell(RNNCell):
"""
mLSTMRNNCell
"""
def __init__(self, input_size, hidden_size, bias = False, output_size = None):
gate_multiplier = 4
super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)
self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))
self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))
self.reset_parameters()
def forward(self, input):
"""
mLSTMRNNCell.forward()
"""
#if not inited or bsz has changed this will create hidden states
self.init_hidden(input.size()[0])
hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
self.hidden = list(
self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,
b_ih=self.b_ih, b_hh=self.b_hh)
)
if self.output_size != self.hidden_size:
self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
return tuple(self.hidden)
def new_like(self, new_input_size=None):
if new_input_size is None:
new_input_size = self.input_size
return type(self)(
new_input_size,
self.hidden_size,
self.bias,
self.output_size)
def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):
"""
mLSTMCell
"""
if input.is_cuda:
igates = F.linear(input, w_ih)
m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
hgates = F.linear(m, w_hh)
state = fusedBackend.LSTMFused.apply
return state(igates, hgates, hidden[1], b_ih, b_hh)
hx, cx = hidden
m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
igates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)
return hy, cy
import torch
from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell
from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
from .cells import mLSTMRNNCell, mLSTMCell
def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
"""
:class:`toRNNBackend`
"""
if bidirectional:
return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
else:
return stackedRNN(inputRNN, num_layers, dropout = dropout)
def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`LSTM`
"""
inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`GRU`
"""
inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`ReLU`
"""
inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`Tanh`
"""
inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`mLSTM`
"""
inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
from . import RNN
from . import reparameterization
from . import fp16_utils
from . import parallel
from .fp16util import (
BN_convert_float,
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
master_params_to_model_params,
tofp16,
)
from .fused_weight_norm import Fused_Weight_Norm
from .fp16_optimizer import fp32_to_fp16, fp16_to_fp32, FP16_Module, FP16_Optimizer
from .loss_scaler import LossScaler, DynamicLossScaler
This diff is collapsed.
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
class tofp16(nn.Module):
"""
Model wrapper that implements::
def forward(self, input):
return input.half()
"""
def __init__(self):
super(tofp16, self).__init__()
def forward(self, input):
return input.half()
def BN_convert_float(module):
'''
Designed to work with network_to_half.
BatchNorm layers need parameters in single precision.
Find all layers and convert them back to float. This can't
be done with built in .apply as that function will apply
fn to all modules, parameters, and buffers. Thus we wouldn't
be able to guard the float conversion based on the module type.
'''
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
module.float()
for child in module.children():
BN_convert_float(child)
return module
def network_to_half(network):
"""
Convert model to half precision in a batchnorm-safe way.
"""
return nn.Sequential(tofp16(), BN_convert_float(network.half()))
def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
r"""
Creates a list of FP32 master parameters for a given model, as in
`Training Neural Networks with Mixed Precision: Real Examples`_.
Args:
model (torch.nn.Module): Existing Pytorch model
flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
Returns:
A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element.
Example::
model_params, master_params = prep_param_lists(model)
.. warning::
Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
.. _`Training Neural Networks with Mixed Precision: Real Examples`:
http://on-demand.gputechconf.com/gtc/2018/video/S81012/
"""
model_params = [param for param in model.parameters() if param.requires_grad]
if flat_master:
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True
# master_params.register_hook(backwards_debug_hook)
if master_params.grad is None:
master_params.grad = master_params.new(*master_params.size())
return model_params, [master_params]
else:
master_params = [param.detach().clone().float() for param in model_params]
for param in master_params:
param.requires_grad = True
return model_params, master_params
def model_grads_to_master_grads(model_params, master_params, flat_master=False):
"""
Copy model gradients to master gradients.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
"""
if flat_master:
# The flattening may incur one more deep copy than is necessary.
master_params[0].grad.data.copy_(
_flatten_dense_tensors([p.grad.data for p in model_params]))
else:
for model, master in zip(model_params, master_params):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else:
master.grad = None
def master_params_to_model_params(model_params, master_params, flat_master=False):
"""
Copy master parameters to model parameters.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
"""
if flat_master:
for model, master in zip(model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master)
else:
for model, master in zip(model_params, master_params):
model.data.copy_(master.data)
import torch
from torch.autograd import Variable
from torch.autograd.function import Function, once_differentiable
import apex._C
def check_contig_cuda(tensors, names):
for tensor, name in zip(tensors, names):
if not tensor.is_contiguous():
raise RuntimeError(name+" with size {} is not contiguous"
.format(tensor.size()))
if not tensor.is_cuda:
raise RuntimeError(name+".is_cuda = False."
"Currently, only cuda tensors are supported.")
class Fused_Weight_Norm(Function):
"""
Implements weight norm along a tensor's slowest dimension using fused kernel launches for
the forward and backward pass.
Accepts fp32 or fp16 input; the output type will match the input type.
Within the kernels, all calculations are performed in fp32 for numerical stability, regardless
of input/output precision.
"""
@staticmethod
def forward(ctx, input, g, dim=0):
"""
:attr:`input` is assumed to be contiguous.
:attr:`input` may be either float or half precision.
The precision of :attr:`output` will match the precision of :attr:`input`.
A float copy of the L2 norm across each slow dimension
is also created and saved for the backward pass.
"""
# torch.cuda.nvtx.range_push("FusedNorm.forward, input.size() = {}"
# .format(input.size()))
check_contig_cuda((input,g),("input","g"))
"""
This is ok, new() treats a torch.Size object properly.
No need to unpack with an asterisk via new(*input.size()).
"""
output = input.new(input.size()).contiguous()
"""
For output with size (slow, faster, faster, ...fastest), we may want
norms with size (slow, 1, 1, ...1), so that if you want retrieve norms
and apply the same normalizing factors to another Tensor "t" with the
same size as output, "t/norms" will broadcast each element of norms
across the corresponding slowest dim of t.
"""
if dim == 0:
norm_size = (output.size(0),) + (1,)*(output.dim() - 1)
elif dim == output.dim() - 1:
norm_size = (1,)*(output.dim() - 1) + (output.size(-1),)
else:
raise RuntimeError("Currently, Fused_Weight_Norm only supports first or last dimension.")
norms = torch.cuda.FloatTensor(*norm_size).contiguous()
"""
Beware: If you call the following:
norms = torch.cuda.FloatTensor(norm_size).contiguous()
the constructor sees a tuple:
FloatTensor( (output_size(0),1,1,...) )
and creates a 1D tensor with values from the tuple:
[output_size(0),1,1,...].
"""
apex._C.weight_norm_fwd(output, norms, input, g, dim)
ctx.save_for_backward(input, g)
# save_for_backward can only save input or output tensors,
# use ctx state to save the norms and dimension:
ctx.norms = norms
ctx.dim = dim
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
"""
:attr:`grad_output` is assumed to be contiguous.
:attr:`grad_output` may be either float or half precision.
The precision of :attr:`grad_input` will match the precision of :attr:`grad_output`.
"""
check_contig_cuda((grad_output), ("grad_output"))
savedInput, savedg = ctx.saved_tensors
savedNorms = ctx.norms
# better safe than sorry
grad_output_contig = grad_output.contiguous()
grad_input = grad_output_contig.new(grad_output.size()).contiguous()
grad_g = savedg.new(savedg.size()).contiguous()
apex._C.weight_norm_bwd(grad_input,
grad_g,
grad_output_contig,
savedInput,
savedg,
savedNorms,
ctx.dim)
return grad_input, grad_g, None
"""
Top of loss_scaler.py stub. Can't figure out a way to get the module file
highlighted in a pretty way, or link back to source.
"""
import torch
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of LossScaler is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.
"""
def __init__(self, scale=1):
self.cur_scale = scale
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
return False
# `overflow` is boolean indicating whether we overflowed in gradient
def update_scale(self, overflow):
pass
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss):
scaled_loss = loss*self.loss_scale
scaled_loss.backward()
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training FP16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, DynamicLossScaler informs :class:`FP16_Optimizer` that an overflow has occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
"""
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
return True
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# Stopgap until upstream fixes sum() on HalfTensors
cpu_sum = float(x.float().sum())
# cpu_sum = float(x.sum())
# print(cpu_sum)
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
# `overflow` is boolean indicating whether we overflowed in gradient
def update_scale(self, overflow):
if overflow:
# self.cur_scale /= self.scale_factor
self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
self.cur_iter += 1
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss):
scaled_loss = loss*self.loss_scale
scaled_loss.backward()
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]
learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()
for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)
"""
from .distributed import DistributedDataParallel
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
def flat_dist_call(tensors, call, extra_args=None):
flat_dist_call.warn_on_half = True
buckets = {}
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
if flat_dist_call.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
flat_dist_call.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
coalesced = _flatten_dense_tensors(bucket)
if extra_args is not None:
call(coalesced, *extra_args)
else:
call(coalesced)
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
class DistributedDataParallel(Module):
"""
:class:`DistributedDataParallel` is a simpler version of upstream :class:`
DistributedDataParallel` that is optimized for use with NCCL. Its usage is designed
to be used in conjunction with apex.parallel.multiproc.py. It assumes that your run
is using multiprocess with 1 GPU/process, that the model is on the correct device,
and that torch.set_device has been used to set the device. Parameters are broadcasted
to the other processes on initialization of DistributedDataParallel, and will be
allreduced in buckets durring the backward pass.
See https://github.com/csarofeen/examples/tree/apex/distributed for detailed usage.
Args:
module: Network definition to be run in multi-gpu/distributed mode.
message_size (Default = 10000000): Minimum number of elements in a communication bucket.
"""
def __init__(self, module, message_size=10000000):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
self.message_size = message_size
#reference to last iterations parameters to see if anything has changed
self.param_refs = []
self.reduction_stream = torch.cuda.Stream()
self.module = module
self.param_list = list(self.module.parameters())
if dist._backend == dist.dist_backend.NCCL:
for param in self.param_list:
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
self.record = []
self.create_hooks()
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
def create_hooks(self):
#all reduce gradient hook
def allreduce_params():
if(self.needs_reduction):
self.needs_reduction = False
self.needs_refresh = False
else:
return
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
t_record = torch.cuda.IntTensor(self.record)
dist.broadcast(t_record, 0)
self.record = [int(entry) for entry in t_record]
def flush_buckets():
if not self.needs_reduction:
return
self.needs_reduction = False
ready = []
for i in range(len(self.param_state)):
if self.param_state[i] == 1:
param = self.param_list[self.record[i]]
if param.grad is not None:
ready.append(param.grad.data)
if(len(ready)>0):
orig_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.reduction_stream):
self.reduction_stream.wait_stream(orig_stream)
flat_dist_call(ready, dist.all_reduce)
torch.cuda.current_stream().wait_stream(self.reduction_stream)
for param_i, param in enumerate(list(self.module.parameters())):
def wrapper(param_i):
def allreduce_hook(*unused):
if self.needs_refresh:
self.record.append(param_i)
Variable._execution_engine.queue_callback(allreduce_params)
else:
Variable._execution_engine.queue_callback(flush_buckets)
self.param_state[self.record.index(param_i)] = 1
self.comm_ready_buckets()
if param.requires_grad:
param.register_hook(allreduce_hook)
wrapper(param_i)
def comm_ready_buckets(self):
ready = []
counter = 0
while counter < len(self.param_state) and self.param_state[counter] == 2:
counter += 1
while counter < len(self.param_state) and self.param_state[counter] == 1:
ready.append(counter)
counter += 1
if not ready:
return
grads = []
for ind in ready:
param_ind = self.record[ind]
if self.param_list[param_ind].grad is not None:
grads.append(self.param_list[param_ind].grad.data)
bucket = []
bucket_inds = []
while grads:
bucket.append(grads.pop(0))
bucket_inds.append(ready.pop(0))
cumm_size = 0
for ten in bucket:
cumm_size += ten.numel()
if cumm_size < self.message_size:
continue
evt = torch.cuda.Event()
evt.record(torch.cuda.current_stream())
evt.wait(stream=self.reduction_stream)
with torch.cuda.stream(self.reduction_stream):
flat_dist_call(bucket, dist.all_reduce)
for ind in bucket_inds:
self.param_state[ind] = 2
def forward(self, *inputs, **kwargs):
"""
Forward function for DDP.
Args:
inputs: inputs that match the module's passed in for initialization.
kwargs: kwargs that match the module's passed in for initialization.
"""
param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
self.needs_refresh = True if not self.param_refs else any(
[param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
)
if self.needs_refresh:
self.record = []
self.param_state = [0 for i in range(len(param_list))]
self.param_refs = param_list
self.needs_reduction = True
return self.module(*inputs, **kwargs)
import torch
import sys
import subprocess
def docstring_hack():
"""
Multiproc file which will launcch a set of processes locally for multi-gpu
usage: python -m apex.parallel.multiproc main.py ...
"""
pass
argslist = list(sys.argv)[1:]
world_size = torch.cuda.device_count()
if '--world-size' in argslist:
argslist[argslist.index('--world-size')+1] = str(world_size)
else:
argslist.append('--world-size')
argslist.append(str(world_size))
workers = []
for i in range(world_size):
if '--rank' in argslist:
argslist[argslist.index('--rank')+1] = str(i)
else:
argslist.append('--rank')
argslist.append(str(i))
stdout = None if i == 0 else open("GPU_"+str(i)+".log", "w")
print(argslist)
p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)
workers.append(p)
for p in workers:
p.wait()
from .weight_norm import WeightNorm
from .reparameterization import Reparameterization
def apply_weight_norm(module, name='', dim=0, hook_child=True):
"""
Applies weight normalization to a parameter in the given module.
If no parameter is provided, applies weight normalization to all
parameters in model (except 1-d vectors and scalars).
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by `name` (e.g. "weight") with two parameters: one specifying the magnitude
(e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
Weight normalization is implemented via a hook that recomputes the weight
tensor from the magnitude and direction before every :meth:`~Module.forward`
call.
By default, with `dim=0`, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
`dim=None`.
See https://arxiv.org/abs/1602.07868
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute the norm
hook_child (boolean, optional): adds reparameterization hook to direct parent of the
parameters. If False, it's added to `module` instead. Default: True
Returns:
The original module with the weight norm hook
Example::
>>> m = apply_weight_norm(nn.Linear(20, 40), name='weight')
Linear (20 -> 40)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])
"""
return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child,
name=name, dim=dim)
def remove_weight_norm(module, name='', remove_all=False):
"""
Removes the weight normalization reparameterization of a parameter from a module.
If no parameter is supplied then all weight norm parameterizations are removed.
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = apply_weight_norm(nn.Linear(20, 40))
>>> remove_weight_norm(m)
"""
return remove_reparameterization(module, reparameterization=WeightNorm,
name=name, remove_all=remove_all)
def apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True):
"""
Applies a given weight reparameterization (such as weight normalization) to
a parameter in the given module. If no parameter is given, applies the reparameterization
to all parameters in model (except 1-d vectors and scalars).
Args:
module (nn.Module): containing module
reparameterization (Reparameterization): reparamaterization class to apply
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to perform reparameterization op
hook_child (boolean, optional): adds reparameterization hook to direct parent of the
parameters. If False, it's added to `module` instead. Default: True
Returns:
The original module with the reparameterization hook
Example::
>>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm)
Linear (20 -> 40)
"""
assert reparameterization is not None
if name != '':
Reparameterization.apply(module, name, dim, reparameterization, hook_child)
else:
names = list(module.state_dict().keys())
for name in names:
apply_reparameterization(module, reparameterization, name, dim, hook_child)
return module
def remove_reparameterization(module, reparameterization=Reparameterization,
name='', remove_all=False):
"""
Removes the given reparameterization of a parameter from a module.
If no parameter is supplied then all reparameterizations are removed.
Args:
module (nn.Module): containing module
reparameterization (Reparameterization): reparamaterization class to apply
name (str, optional): name of weight parameter
remove_all (bool, optional): if True, remove all reparamaterizations of given type. Default: False
Example:
>>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm)
>>> remove_reparameterization(m)
"""
if name != '' or remove_all:
to_remove = []
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, reparameterization) and (hook.name == name or remove_all):
hook.remove(module)
to_remove.append(k)
if len(to_remove) > 0:
for k in to_remove:
del module._forward_pre_hooks[k]
return module
if not remove_all:
raise ValueError("reparameterization of '{}' not found in {}"
.format(name, module))
else:
modules = [module]+[x for x in module.modules()]
for m in modules:
remove_reparameterization(m, reparameterization=reparameterization, remove_all=True)
return module
import torch
from torch.nn.parameter import Parameter
import sys
class Reparameterization(object):
"""
Class interface for performing weight reparameterizations
Arguments:
name (str): name of weight parameter
dim (int): dimension over which to compute the norm
module (nn.Module): parent module to which param `name` is registered to
retain_forward (bool, optional): if False deletes weight on call to
module.backward. Used to avoid memory leaks with DataParallel Default: True
Attributes:
reparameterization_names (list, str): contains names of all parameters
needed to compute reparameterization.
backward_hook_key (int): torch.utils.hooks.RemovableHandle.id for hook used in module backward pass.
"""
def __init__(self, name, dim, module, retain_forward=True):
self.name = name
self.dim = dim
self.evaluated = False
self.retain_forward = retain_forward
self.reparameterization_names = []
self.backward_hook_key = None
self.module = module
def compute_weight(self, module=None, name=None):
"""
Computes reparameterized weight value to assign value to module attribute
with name `name`.
See WeightNorm class for example.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
Returns:
w (Tensor): Tensor object containing value of reparameterized weight
"""
raise NotImplementedError
def reparameterize(self, name, weight, dim):
"""
Creates Parameters to be used for reparameterization and creates names that
for attributes for the module these Parameters will correspond to.
The parameters will be registered according to the names provided.
See WeightNorm class for example.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute parameterization
Returns:
names (list, str): names of Parameters to be used for reparameterization
params (list, Parameter): Parameters to be used for reparameterization
"""
raise NotImplementedError
@staticmethod
def apply(module, name, dim, reparameterization=None, hook_child=True):
"""
Applies reparametrization to module's `name` parameter and modifies instance attributes as appropriate.
`hook_child` adds reparameterization hook to direct parent of the parameters. If False, it's added to `module` instead.
"""
if reparameterization is None:
reparameterization = Reparameterization
module2use, name2use = Reparameterization.get_module_and_name(module, name)
# does not work on sparse
if name2use is None or isinstance(module2use, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
return
if hook_child:
fn = reparameterization(name2use, dim, module2use)
else:
fn = reparameterization(name, dim, module)
weight = getattr(module2use, name2use)
if weight.dim() <= 1:
return
# remove weight from parameter list
del module2use._parameters[name2use]
# add parameters of reparameterization of parameter to module
names, params = fn.reparameterize(name2use, weight, dim)
for n, p in zip(names, params):
module2use.register_parameter(n, p)
# add parameters to reparameterization so they can be removed later
fn.reparameterization_names = names
setattr(module2use, name2use, None)
hook_module = module2use
if not hook_child:
hook_module = module
# recompute weight before every forward()
hook_module.register_forward_pre_hook(fn)
# remove weight during backward
handle = hook_module.register_backward_hook(fn.backward_hook)
# get hook key so we can delete it later
fn.backward_hook_key = handle.id
return fn
@staticmethod
def get_module_and_name(module, name):
"""
recursively fetches (possible) child module and name of weight to be reparameterized
"""
name2use = None
module2use = None
names = name.split('.')
if len(names) == 1 and names[0] != '':
name2use = names[0]
module2use = module
elif len(names) > 1:
module2use = module
name2use = names[0]
for i in range(len(names)-1):
module2use = getattr(module2use, name2use)
name2use = names[i+1]
return module2use, name2use
def get_params(self, module):
"""gets params of reparameterization based on known attribute names"""
return [getattr(module, n) for n in self.reparameterization_names]
def remove(self, module):
"""removes reparameterization and backward hook (does not remove forward hook)"""
module2use, name2use = Reparameterization.get_module_and_name(module, self.name)
for p in self.get_params(module2use):
p.requires_grad = False
weight = self.compute_weight(module2use, name2use)
delattr(module2use, name2use)
for n in self.reparameterization_names:
del module2use._parameters[n]
module2use.register_parameter(name2use, Parameter(weight.data))
del module._backward_hooks[self.backward_hook_key]
def __call__(self, module, inputs):
"""callable hook for forward pass"""
module2use, name2use = Reparameterization.get_module_and_name(module, self.name)
_w = getattr(module2use, name2use)
if not self.evaluated or _w is None:
setattr(module2use, name2use, self.compute_weight(module2use, name2use))
self.evaluated = True
def backward_hook(self, module, grad_input, grad_output):
"""callable hook for backward pass"""
module2use, name2use = Reparameterization.get_module_and_name(module, self.name)
wn = getattr(module2use, name2use)
self.evaluated = False
import torch
from torch.nn.parameter import Parameter
from ..fp16_utils import Fused_Weight_Norm
import time
from .reparameterization import Reparameterization
def _norm(p, dim):
"""Computes the norm over all dimensions except dim"""
if dim is None:
return p.norm()
elif dim == 0:
output_size = (p.size(0),) + (1,) * (p.dim() - 1)
return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
elif dim == p.dim() - 1:
output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
return _norm(p.transpose(0, dim), 0).transpose(0, dim)
HALF_TYPES = (torch.cuda.HalfTensor, torch.HalfTensor)
class WeightNorm(Reparameterization):
"""
Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by `name` (e.g. "weight") with two parameters: one specifying the magnitude
(e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
Weight normalization is implemented via a hook that recomputes the weight
tensor from the magnitude and direction before every :meth:`~Module.forward`
call.
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
By default, with `dim=0`, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
`dim=None`.
"""
def compute_weight(self, module=None, name=None):
"""
Computes weight normalized weight value to assign value to module attribute
with name `name`.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
Returns:
w (Tensor): Tensor object containing value of reparameterized weight
"""
if module is None:
module = self.module
if name is None:
name = self.name
module, name = Reparameterization.get_module_and_name(module, name)
g = getattr(module, name + '_g')
v = getattr(module, name + '_v')
fused_weight_norm = Fused_Weight_Norm.apply
v = v.contiguous()
w = fused_weight_norm(v, g, self.dim)
return w
def reparameterize(self, name, weight, dim):
"""
Creates Parameters v and gto be used for weight normalization
and creates names that for attributes for the module these Parameters
will correspond to. The parameters will be registered according to the names
provided.
Arguments:
module (nn.Module): module with weight we'd like to reparameterize
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute parameterization
Returns:
names (list, str): names of Parameters to be used for reparameterization
params (list, Parameter): Parameters to be used for reparameterization
"""
names = [name + '_g', name + '_v']
params = [Parameter(_norm(weight, dim).data), Parameter(weight.data)]
return names, params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment