Commit 2cbca1a4 authored by Michael Carilli's avatar Michael Carilli

Merge branch 'master' into api_refactor

parents a9a3fe57 340e71a4
......@@ -3,23 +3,13 @@
from . import fp16_utils
from . import parallel
from . import amp
try:
from . import optimizers
except ImportError:
# An attempt to fix https://github.com/NVIDIA/apex/issues/97. I'm not sure why 97 is even
# happening because Python modules should only be imported once, even if import is called
# multiple times.
try:
_ = warned_optimizers
except NameError:
print("Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.")
warned_optimizers = True
try:
from . import normalization
except ImportError:
try:
_ = warned_normalization
except NameError:
print("Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.")
warned_normalization = True
# For optimizers and normalization there is no Python fallback.
# Absence of the CUDA backend is a hard error.
# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
# to be triggered lazily: if someone has installed with --cpp_ext and --cuda_ext,
# they expect those backends to be available, but if for some reason they actually
# aren't (for example because they built improperly in a way that isn't revealed
# until load time), the error message should be timely and visible.
from . import optimizers
from . import normalization
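For context, a minimal sketch of the lazy-import pattern the comment above asks for, in the spirit of the `importlib.import_module` calls added to FusedLayerNorm and FusedAdam later in this diff (the wrapper function here is illustrative, not part of apex):

```python
# Illustrative sketch only: defer importing a compiled extension until first use,
# so a broken --cuda_ext build surfaces when the feature is constructed rather
# than at "import apex".
import importlib

_fused_adam_cuda = None

def _load_fused_adam_cuda():
    global _fused_adam_cuda
    if _fused_adam_cuda is None:
        # An ImportError (e.g. missing --cuda_ext build) is raised here, lazily.
        _fused_adam_cuda = importlib.import_module("fused_adam_cuda")
    return _fused_adam_cuda
```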
......@@ -40,23 +40,16 @@ class AmpHandle(object):
'use `optimizer.scale_loss(loss)`.')
# TODO: this code block is duplicated here and `opt.py`. Unify.
loss_backward = loss.backward
def warning_wrapper():
warnings.warn("You called .backward() on the unscaled loss "
"inside a scale_loss block. This is almost "
"certainly an error.", stacklevel=2)
loss_backward()
loss.backward = warning_wrapper
loss_scale = self._default_scaler.loss_scale()
yield loss * loss_scale
loss.backward = loss_backward
should_skip = self._default_scaler.unscale_and_update(
optimizer.param_groups, loss_scale)
if should_skip:
optimizer_step = optimizer.step
def skip_step():
logging.info('Gradient overflow, skipping update')
logger = logging.getLogger('apex.amp')
logger.warning('Gradient overflow, skipping update')
optimizer.step = optimizer_step
optimizer.step = skip_step
......
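Taken together, the lines above replace `optimizer.step` with a one-shot wrapper when the scaler reports an overflow. A minimal sketch of that pattern, with the helper name invented here purely for illustration:

```python
import logging

def _install_skip_step(optimizer):
    """Make the next optimizer.step() log a warning and skip the update."""
    optimizer_step = optimizer.step  # keep a reference to the real step

    def skip_step():
        logger = logging.getLogger('apex.amp')
        logger.warning('Gradient overflow, skipping update')
        optimizer.step = optimizer_step  # restore the real step for later iterations

    optimizer.step = skip_step
```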
......@@ -21,14 +21,6 @@ class OptimWrapper(object):
yield loss
return
loss_backward = loss.backward
def warning_wrapper():
warnings.warn("You called .backward() on the unscaled loss "
"inside a scale_loss block. This is almost "
"certainly an error.", stacklevel=2)
loss_backward()
loss.backward = warning_wrapper
# When there are multiple losses per-optimizer, we need
# to save out current grad accumulation, since we won't be
able to unscale this particular loss once the grads are
......@@ -44,7 +36,6 @@ class OptimWrapper(object):
loss_scale = self._cur_loss_scaler().loss_scale()
yield loss * loss_scale
loss.backward = loss_backward
self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
self._optimizer.param_groups, loss_scale)
......@@ -76,7 +67,8 @@ class OptimWrapper(object):
'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.')
if any(self._skip_next):
logging.info('Gradient overflow, skipping update')
logger = logging.getLogger('apex.amp')
logger.info('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss
else:
return self._optimizer.step(closure=closure)
......
import torch
import logging
# from apex_C import scale_check_overflow
# Python stopgap, until we get a future-proof kernel into upstream
def scale_check_overflow(d_grads, scale):
def scale_check_overflow_python(d_grads, scale):
# Exception handling for 18.04 compatibility
try:
cpu_sum = float(d_grads.float().sum())
......@@ -18,28 +18,60 @@ def scale_check_overflow(d_grads, scale):
return False
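The body of the Python fallback is truncated in this hunk; as a hedged sketch, a CPU-side overflow check of this shape usually looks something like the following (not the exact apex implementation):

```python
import torch

def has_overflow(grad):
    """Return True if grad appears to contain inf/NaN, via a CPU-side sum check."""
    cpu_sum = float(grad.float().sum())
    # inf survives the sum; NaN propagates and fails the self-equality test.
    return cpu_sum in (float('inf'), -float('inf')) or cpu_sum != cpu_sum
```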
class LossScaler(object):
warned_no_fused_kernel = False
warned_fp16_grad = False
has_fused_kernel = False
def __init__(self):
self._loss_scale = 2.**16
self._max_loss_scale = 2.**24
self._scale_seq_len = 2000
self._unskipped = 0
self._has_overflow = False
# self._overflow_buf = torch.cuda.ByteTensor(1024,)
try:
import amp_C
LossScaler.has_fused_kernel = True
LossScaler.scale_check_overflow_cuda = amp_C.scale_check_overflow
self._overflow_buf = torch.cuda.IntTensor([0])
except ImportError as err:
if not LossScaler.warned_no_fused_kernel:
print("Warning: Amp fused downscale kernel is unavailable, possibly because apex "
"was installed without --cuda_ext. Using Python fallback. ImportError was: ",
err)
LossScaler.has_fused_kernel = False
LossScaler.warned_no_fused_kernel = True
def loss_scale(self):
return self._loss_scale
def unscale_and_update(self, param_groups, scale):
# self._overflow_buf.zero_()
if LossScaler.has_fused_kernel:
self._overflow_buf.zero_()
self._has_overflow = False
for p in iter_params(param_groups):
if p.grad is not None:
self._has_overflow = scale_check_overflow(p.grad.data,
1. / scale)
if LossScaler.has_fused_kernel and p.grad.data.type() == "torch.cuda.FloatTensor":
LossScaler.scale_check_overflow_cuda(p.grad.data,
1./scale,
self._overflow_buf,
p.grad.data)
else:
if (p.grad.data.type() != "torch.cuda.FloatTensor"
and not LossScaler.warned_fp16_grad):
logger = logging.getLogger("apex.amp")
logger.warning("Incoming grads are not fp32 (not master grads). "
"Downscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_fp16_grad = True
self._has_overflow = scale_check_overflow_python(p.grad.data,
1./scale)
if self._has_overflow:
break
# if self._overflow_buf.any():
# If the fused kernel is available, we only need one D2H memcpy and sync.
if LossScaler.has_fused_kernel and not self._has_overflow:
self._has_overflow = self._overflow_buf.item()
if self._has_overflow:
should_skip = True
self._loss_scale /= 2.
......
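For readers skimming the LossScaler hunk above: only the halve-on-overflow branch is visible here, but together with `_scale_seq_len` and `_unskipped` the policy amounts to something like the following sketch (the growth rule is assumed, not shown in this hunk):

```python
class ToyLossScaler:
    """Illustrative dynamic loss scaling; the grow-after-N-clean-steps rule is an assumption."""
    def __init__(self):
        self.loss_scale = 2.**16      # initial scale, as in the diff
        self.max_loss_scale = 2.**24
        self.scale_seq_len = 2000     # clean steps before attempting to grow
        self.unskipped = 0

    def update(self, has_overflow):
        if has_overflow:
            self.loss_scale /= 2.     # shrink on overflow
            self.unskipped = 0
            return True               # caller should skip optimizer.step()
        self.unskipped += 1
        if self.unskipped == self.scale_seq_len:
            # assumed growth rule: double, capped at max_loss_scale
            self.loss_scale = min(self.loss_scale * 2., self.max_loss_scale)
            self.unskipped = 0
        return False
```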
......@@ -3,11 +3,13 @@ import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
import fused_layer_norm_cuda
import importlib
class FusedLayerNormAffineFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
self.normalized_shape = normalized_shape
self.eps = eps
......@@ -31,6 +33,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
class FusedLayerNormFunction(torch.autograd.Function):
def __init__(self, normalized_shape, eps=1e-6):
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
self.normalized_shape = normalized_shape
self.eps = eps
......@@ -117,6 +121,10 @@ class FusedLayerNorm(torch.nn.Module):
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(FusedLayerNorm, self).__init__()
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
......
......@@ -214,3 +214,69 @@ class FP16_Optimizer(object):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
current.data.copy_(saved.data)
import types
import torch
import fused_adam_cuda
import importlib
class FusedAdam(torch.optim.Optimizer):
......@@ -36,6 +36,9 @@ class FusedAdam(torch.optim.Optimizer):
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
......
......@@ -8,16 +8,15 @@ else:
ReduceOp = torch.distributed.deprecated.reduce_op
from .distributed import DistributedDataParallel, Reducer
# This is tricky because I'd like SyncBatchNorm to be exposed the same way
# for both the cuda-enabled and python-fallback versions, and I don't want
# to suppress the error information.
try:
import syncbn
from .optimized_sync_batchnorm import SyncBatchNorm
except ImportError:
try:
_ = warned_syncbn
except NameError:
print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.")
warned_syncbn = True
except ImportError as err:
from .sync_batchnorm import SyncBatchNorm
SyncBatchNorm.syncbn_import_error = err
def convert_syncbn_model(module, process_group=None, channel_last=False):
'''
......
import torch
# from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
try:
from apex_C import flatten
from apex_C import unflatten
except ImportError:
try:
_ = warned_flatten
except NameError:
print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.")
warned_flatten = True
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from collections import OrderedDict
from itertools import chain
import copy
import importlib
imported_flatten_impl = False
def import_flatten_impl():
global flatten_impl, unflatten_impl, imported_flatten_impl
try:
import apex_C
flatten_impl = apex_C.flatten
unflatten_impl = apex_C.unflatten
except ImportError:
print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.")
flatten_impl = torch._utils._flatten_dense_tensors
unflatten_impl = torch._utils._unflatten_dense_tensors
imported_flatten_impl = True
def flatten(bucket):
if not imported_flatten_impl:
import_flatten_impl()
return flatten_impl(bucket)
def unflatten(coalesced, bucket):
if not imported_flatten_impl:
import_flatten_impl()
return unflatten_impl(coalesced, bucket)
# apply_dist_call requires that tensors in 'bucket' are all the same type.
def apply_flat_dist_call(bucket, call, extra_args=None):
......
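As a quick sanity check on the fallback path above, `flatten`/`unflatten` round-trip a bucket of same-typed tensors; this sketch uses the public torch helpers that the fallback binds to (apex_C is only a faster drop-in):

```python
import torch
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten

bucket = [torch.randn(3), torch.randn(2, 2)]   # tensors in a bucket share dtype/device
coalesced = flatten(bucket)                    # one contiguous 1-D tensor
restored = unflatten(coalesced, bucket)        # tensors shaped like the originals
assert all(torch.equal(a, b) for a, b in zip(bucket, restored))
```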
......@@ -2,6 +2,7 @@ import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn import functional as F
import syncbn
from .optimized_sync_batchnorm_kernel import SyncBatchnormFunction
......
......@@ -45,7 +45,14 @@ class SyncBatchNorm(_BatchNorm):
>>> out = sbn(inp)
"""
warned = False
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
if not SyncBatchNorm.warned:
print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error)
SyncBatchNorm.warned = True
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
......
#include <torch/extension.h>
void scale_check_overflow_cuda(const at::Tensor& grads,
float scale,
const at::Tensor& d_buf,
const at::Tensor& downscaled_grads);
void scale_check_overflow(at::Tensor grads,
float scale,
at::Tensor overflow_buf,
at::Tensor downscaled_grads)
// const at::optional<at::Tensor> downscaled_grads)
{
AT_CHECK(grads.type().is_cuda(), "grads must be a CUDA tensor");
AT_CHECK(grads.is_contiguous(), "grads must be contiguous");
AT_CHECK(overflow_buf.type().is_cuda(), "overflow_buf must be a CUDA tensor");
AT_CHECK(overflow_buf.is_contiguous(), "overflow_buf must be contiguous");
AT_CHECK(downscaled_grads.type().is_cuda(), "downscaled_grads must be a CUDA tensor");
AT_CHECK(downscaled_grads.is_contiguous(), "downscaled_grads must be contiguous");
// Make sure we are downscaling the FP32 master grads
AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
"The output grads supplied to scale_check_overflow should be fp32 (master grads).")
AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
scale_check_overflow_cuda(grads, scale, overflow_buf, downscaled_grads);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
}
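A usage sketch of this binding from Python, mirroring how LossScaler calls it elsewhere in this commit (assumes apex was built with --cuda_ext and a CUDA device is available):

```python
# In-place downscale of an fp32 CUDA grad tensor with overflow detection.
import torch
import amp_C  # present only when apex was built with --cuda_ext

grad = torch.randn(1024, device="cuda")   # stands in for p.grad.data (fp32 master grad)
overflow_buf = torch.cuda.IntTensor([0])  # same buffer LossScaler keeps around
loss_scale = 2.**16

amp_C.scale_check_overflow(grad, 1. / loss_scale, overflow_buf, grad)
if overflow_buf.item():
    print("overflow detected; skip this optimizer step and reduce the loss scale")
```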
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 1024
#define NBLOCKS 160
// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
// gradients should always be fp32).
// This can be optimized with ILP but it's fine for now.
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
float* out,
int n,
float scale,
volatile int* overflow_global)
{
__shared__ int overflow;
int tid = blockIdx.x*blockDim.x + threadIdx.x;
int stride = gridDim.x*blockDim.x;
// Non-divergent exit condition for the __syncthreads
for(int i = tid; i - threadIdx.x < n; i += stride)
{
if(threadIdx.x == 0)
overflow = *overflow_global;
__syncthreads();
if(overflow == 1)
break;
if(i < n)
{
float incoming_val = static_cast<float>(in[i]);
if(isfinite(incoming_val))
out[i] = incoming_val*scale;
else
*overflow_global = 1; // Blindly fire off a write. These will race but that's ok.
// This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
// I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
// It's possible we can just lean on the cache (no smem or syncs) and still be fast.
}
}
}
void scale_check_overflow_cuda
(const at::Tensor& grads,
float scale,
const at::Tensor& overflow_buf,
const at::Tensor& downscaled_grads)
{
using namespace at;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int n = grads.numel();
// Lock the output (downscaled) type to float.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grads.type(),
"scale_check_overflow_cuda",
[&]
{
// using accscalar_t = acc_type<scalar_t, true>;
scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
(grads.data<scalar_t>(),
downscaled_grads.data<float>(),
n,
scale,
overflow_buf.data<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -21,8 +21,8 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
......@@ -32,7 +32,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
const at::optional<at::Tensor> weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
......@@ -41,7 +41,7 @@ at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
......@@ -57,8 +57,8 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
......@@ -68,7 +68,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
const at::optional<at::Tensor> weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
......@@ -78,7 +78,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
......
......@@ -305,8 +305,8 @@ __global__ void batchnorm_forward_kernel(
const int bs) {
auto m_c = mean[blockIdx.x];
auto inv_std_c = inv_std[blockIdx.x];
auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);
for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
......@@ -370,8 +370,12 @@ __global__ void reduce_bn_kernel(
sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
if (thread_id == 0) {
if (grad_bias != NULL) {
grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
}
if (grad_weight != NULL) {
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
}
mean_dy[blockIdx.x] = sum_dy / total_item_num;
mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
}
......@@ -393,7 +397,7 @@ __global__ void batchnorm_backward_kernel(
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto factor_1_c = inv_std[blockIdx.x];
auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) * factor_1_c;
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
......@@ -603,8 +607,8 @@ __global__ void batchnorm_forward_c_last_kernel(
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = static_cast<accscalar_t>(weight[c_offset]);
auto s_c = static_cast<accscalar_t>(shift[c_offset]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
......@@ -749,16 +753,24 @@ __global__ void reduce_bn_c_last_kernel(
merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
if (threadIdx.y == 0 && c_offset < stride) {
if (grad_bias != NULL) {
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
}
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
}
}
} else {
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
if (grad_bias != NULL) {
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
}
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
}
......@@ -793,7 +805,7 @@ __global__ void batchnorm_backward_c_last_kernel(
auto m_c = mean[c_offset];
auto m_dy_c = mean_dy[c_offset];
auto factor_1_c = inv_std[c_offset];
auto factor_2_c = static_cast<accscalar_t>(weight[c_offset]) * factor_1_c;
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
......@@ -850,8 +862,8 @@ at::Tensor batchnorm_forward_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift) {
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
at::Tensor out = at::empty_like(input);
......@@ -866,29 +878,34 @@ at::Tensor batchnorm_forward_CUDA(
const dim3 grid(feature_size, batch_group_size, grid_z);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
out.data<scalar_t>(),
space_size,
batch_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
shift.has_value() ? shift.value().data<scalar_t>() : NULL,
out.data<scalar_t>(),
space_size,
batch_size);
......@@ -902,7 +919,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight)
const at::optional<at::Tensor> weight)
{
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
......@@ -911,8 +928,16 @@ std::vector<at::Tensor> reduce_bn_CUDA(
at::Tensor mean_dy = at::empty({feature_size}, mean.options());
at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor grad_weight = at::empty({feature_size}, weight.options());
at::Tensor grad_bias = at::empty({feature_size}, weight.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
if (weight.has_value()) {
grad_weight = at::empty({feature_size}, weight.value().options());
grad_bias = at::empty({feature_size}, weight.value().options());
} else {
grad_weight = at::empty({0}, mean.options());
grad_bias = at::empty({0}, mean.options());
}
auto space_size = get_tensor_spatial_size(input);
......@@ -922,7 +947,9 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const dim3 grid(feature_size);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
......@@ -932,14 +959,17 @@ std::vector<at::Tensor> reduce_bn_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
batch_size,
feature_size,
space_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
......@@ -949,8 +979,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
batch_size,
feature_size,
space_size);
......@@ -965,7 +995,7 @@ at::Tensor batchnorm_backward_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const auto batch_size = input.size(0);
......@@ -984,7 +1014,9 @@ at::Tensor batchnorm_backward_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
......@@ -992,7 +1024,7 @@ at::Tensor batchnorm_backward_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
......@@ -1000,7 +1032,10 @@ at::Tensor batchnorm_backward_CUDA(
batch_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
......@@ -1008,7 +1043,7 @@ at::Tensor batchnorm_backward_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
......@@ -1099,8 +1134,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift) {
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
......@@ -1113,7 +1148,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.type().scalarType() == at::ScalarType::Float) {
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
......@@ -1121,15 +1156,17 @@ at::Tensor batchnorm_forward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
out.data<scalar_t>(),
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
......@@ -1137,8 +1174,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
shift.has_value() ? shift.value().data<scalar_t>() : NULL,
out.data<scalar_t>(),
reduction_size,
stride);
......@@ -1152,14 +1189,23 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight) {
const at::optional<at::Tensor> weight) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor mean_dy = at::empty({stride}, mean.options());
at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
at::Tensor grad_weight = at::empty({stride}, weight.options());
at::Tensor grad_bias = at::empty({stride}, weight.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
if (weight.has_value()) {
grad_weight = at::empty({stride}, weight.value().options());
grad_bias = at::empty({stride}, weight.value().options());
} else {
// because I cannot return an uninitialized at::Tensor
grad_weight = at::empty({0}, mean.options());
grad_bias = at::empty({0}, mean.options());
}
dim3 block;
dim3 grid;
......@@ -1173,7 +1219,9 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
}
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value()
&& weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
......@@ -1186,15 +1234,18 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
weight.has_value() ? grad_weight.data<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.data<accscalar_t>() : NULL,
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
......@@ -1207,8 +1258,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
staging_data_ptr,
semaphores_ptr,
reduction_size,
......@@ -1224,7 +1275,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const auto stride = input.size(input.ndimension()-1);
......@@ -1239,7 +1290,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.type().scalarType() == at::ScalarType::Float) {
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
......@@ -1248,7 +1299,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
......@@ -1256,8 +1307,10 @@ at::Tensor batchnorm_backward_c_last_CUDA(
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
......@@ -1266,7 +1319,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
......
......@@ -119,7 +119,9 @@ def main():
if args.static_loss_scale != 1.0:
if not args.fp16:
print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")
print("Warning: static_loss_scale != 1.0 is only necessary with --fp16. "
"Resetting static_loss_scale to 1.0")
args.static_loss_scale = 1.0
# create model
if args.pretrained:
......@@ -273,8 +275,8 @@ class data_prefetcher():
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True)
self.next_target = self.next_target.cuda(async=True)
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
if args.fp16:
self.next_input = self.next_input.half()
else:
......
......@@ -256,8 +256,8 @@ class data_prefetcher():
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True)
self.next_target = self.next_target.cuda(async=True)
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half.
# Type conversions are done internally on the fly within patched torch functions.
# if args.fp16:
......
......@@ -265,8 +265,8 @@ class data_prefetcher():
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True)
self.next_target = self.next_target.cuda(async=True)
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
if args.fp16:
self.next_input = self.next_input.half()
else:
......
......@@ -259,8 +259,8 @@ class data_prefetcher():
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(async=True)
self.next_target = self.next_target.cuda(async=True)
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
if args.fp16:
self.next_input = self.next_input.half()
else:
......@@ -377,8 +377,6 @@ def validate(val_loader, model, criterion):
while input is not None:
i += 1
target = target.cuda(async=True)
# compute output
with torch.no_grad():
output = model(input)
......
......@@ -9,10 +9,26 @@ The trained model can then be used by the generate script to generate new text.
`main_fp16_optimizer.py` with `--fp16` demonstrates use of `apex.fp16_utils.FP16_Optimizer` to automatically manage master parameters and loss scaling.
These examples are intended as an illustration of the mixed precision recipe, not necessarily as a performance showcase. However, they do demonstrate certain best practices.
First, a default loss scale of 128.0 is used. In our testing, this improves converged test perplexity modestly with mixed precision, from around 93 with loss scale 1.0 to around 90 with loss scale 128.0.
Second, to enable Tensor Core use with `--fp16` and improve performance, dimensions that participate in GEMMs in the model are made multiples of 8. Specifically, these are
* dictionary length (ntokens in `main.py`),
* embedding size (`--emsize`),
* hidden size (`--nhid`), and
* batch size (`--batch_size`).
The dictionary length is a property of the dataset, and is not controlled by a command line argument. In `main.py`, `corpus = data.Corpus(args.data, pad_to_multiple_of=8)` and the `Corpus` constructor in
`data.py` ensure that the dictionary length is a multiple of 8.
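The padding itself is simple; a hedged sketch of what `pad_to_multiple_of=8` might do inside the `Corpus` constructor (the dummy token names below are made up for illustration):

```python
# Hypothetical sketch: pad the dictionary so embedding / output-projection GEMM
# dimensions stay multiples of 8 and remain Tensor Core friendly.
def pad_dictionary(word2idx, multiple=8):
    pad_id = 0
    while len(word2idx) % multiple != 0:
        word2idx[f"<pad_{pad_id}>"] = len(word2idx)  # dummy tokens, never produced by the data
        pad_id += 1
    return word2idx

vocab = {"the": 0, "cat": 1, "sat": 2}
assert len(pad_dictionary(vocab)) % 8 == 0
```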
Also, for mixed precision performance, a good general rule is: the more work you give the GPU, the better. Bigger models and larger batch sizes supply the cores with more work and do a better job saturating the device. A (very rough) way to check if you're saturating the device is to run `nvidia-smi` from another terminal, and see what fraction of device memory you're using. This will tell you how much leeway you have to increase model or batch size.
```bash
python main.py --cuda --epochs 6 # Train a LSTM on Wikitext-2 with CUDA, reaching perplexity of 117.61
python main.py --cuda --epochs 6 --tied # Train a tied LSTM on Wikitext-2 with CUDA, reaching perplexity of 110.44
python main.py --cuda --tied # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs, reaching perplexity of 87.17
python main.py --cuda --epochs 6 # Train a LSTM on Wikitext-2 with CUDA
python main.py --cuda --epochs 6 --fp16 # Train a LSTM on Wikitext-2 with CUDA and mixed precision
python main.py --cuda --epochs 6 --tied # Train a tied LSTM on Wikitext-2 with CUDA
python main.py --cuda --tied # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs
python generate.py # Generate samples from the trained LSTM model.
```
......@@ -67,16 +83,11 @@ optional arguments:
```
which triggers the use of dynamic loss scaling. Supplying `--dynamic-loss-scale` will override the `--loss_scale` argument, if any.
With these arguments, a variety of models can be tested.
As an example, the following arguments produce slower but better models:
With these arguments, a variety of models can be tested. For example:
```bash
python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 # Test perplexity of 80.97
python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied # Test perplexity of 75.96
python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 # Test perplexity of 77.42
python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied # Test perplexity of 72.30
python main.py --cuda --emsize 656 --nhid 656 --dropout 0.5 --epochs 40
python main.py --cuda --emsize 656 --nhid 656 --dropout 0.5 --epochs 40 --tied
python main.py --cuda --emsize 1504 --nhid 1504 --dropout 0.65 --epochs 40
python main.py --cuda --emsize 1504 --nhid 1504 --dropout 0.65 --epochs 40 --tied
```
Perplexities on PTB are equal or better than
[Recurrent Neural Network Regularization (Zaremba et al. 2014)](https://arxiv.org/pdf/1409.2329.pdf)
and are similar to [Using the Output Embedding to Improve Language Models (Press & Wolf 2016)](https://arxiv.org/abs/1608.05859) and [Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling (Inan et al. 2016)](https://arxiv.org/pdf/1611.01462.pdf), though both of these papers have improved perplexities by using a form of recurrent dropout [(variational dropout)](http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks).