Commit 0c74571f authored by Wil Kong, committed by mcarilli

Add softmax cross entropy loss with label smoothing support. (#295)

* Add softmax cross entropy loss with label smoothing support.

* Fix deprecation of AT_DISPATCH_XXX and several minor issues.

* Fix issues raised by reviewers.

* Add FB license.

* Remove code generation constraints.

* Add a simple unittest for label smoothing.
parent 3e2883dd
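The files below add the C++ binding, the Python autograd wrapper, a unit test, and the setup.py build hook for the fused op. For context, here is a minimal usage sketch (not part of the commit itself); it assumes apex was built with the --xentropy option added to setup.py below so that xentropy_cuda and apex.contrib.xentropy import cleanly, and the shapes and constants are illustrative only.

import torch
from apex.contrib.xentropy import SoftmaxCrossEntropyLoss

# Illustrative sizes: 16 rows of half-precision logits over a 32K-entry vocabulary.
logits = torch.randn(16, 32320, dtype=torch.half, device='cuda', requires_grad=True)
labels = torch.randint(0, 32320, (16,), device='cuda')

# Signature: apply(logits, labels, smoothing, padding_idx, half_to_float).
# As in the unit test below, half_to_float is set when the logits are torch.half.
losses = SoftmaxCrossEntropyLoss.apply(logits, labels, 0.1, 0, logits.dtype == torch.half)
losses.sum().backward()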
#include <torch/extension.h>

// CUDA forward declarations

std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing,
    const bool half_to_float) {
    CHECK_CUDA(input);
    CHECK_INPUT(labels);

    return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}

at::Tensor softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    const at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing) {
    CHECK_CUDA(grad_loss);
    CHECK_CUDA(logits);
    CHECK_INPUT(max_log_sum_exp);
    CHECK_INPUT(labels);

    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}
import torch
from apex.contrib import xentropy as label_smoothing
import unittest
import warnings
import random
import numpy as np
import time

def label_smoothing_raw(x, target, padding_idx, smoothing):
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)

    non_pad_mask = (target != padding_idx)
    nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
    nll_loss = nll_loss.squeeze(1)[non_pad_mask]
    smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]
    loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
    return loss

def label_smoothing_opt_1(x, target, padding_idx, smoothing):
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)

    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss

class LabelSmoothingTest(unittest.TestCase):
    def setUp(self, seed=1234):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Set pytorch print precision
        torch.set_printoptions(precision=10)

    def gen_test_inputs(self, N, T, H, smoothing, padding_idx):
        logits = torch.randn((N*T, H), dtype=torch.half, device='cuda',
            requires_grad=True)
        labels = torch.randint(0, H, [N*T], device='cuda')
        for i in random.sample(range(N*T), N*T//6):
            labels[i] = padding_idx
        half_to_float = (logits.dtype == torch.half)
        return logits, labels, half_to_float

    def print_max_diff_elem(self, ref, tst):
        ref, tst = ref.flatten(), tst.flatten()
        diff = (ref - tst).abs().max()
        idx = (ref - tst).abs().argmax()
        print("Max atol idx: {}, diff: {:.6f}, ref: {:.6f}, tst: {:.6f}".format(
            idx, diff, ref[idx], tst[idx]))

    def test_label_smoothing_function(self):
        # Set label smoothing configuration
        smoothing, padding_idx = 0.1, 0
        N, T, H = 128, 74, 32320
        iters = 10
        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply

        for i in range(iters):
            logits, labels, half_to_float = self.gen_test_inputs(
                N, T, H, smoothing, padding_idx)

            # Run original softmax cross entropy with label smoothing
            logits.grad = None
            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
            loss = losses.sum()
            loss.backward()

            ref_loss = loss.clone().detach()
            ref_grad = logits.grad.clone().detach()

            # Run optimized softmax cross entropy with label smoothing
            logits.grad = None
            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
            loss = losses.sum()
            loss.backward()

            val_loss = loss.clone().detach()
            val_grad = logits.grad.clone().detach()

            # Validate
            self.print_max_diff_elem(ref_grad, val_grad)
            self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5))
            self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5))

    def test_label_smoothing_perf(self):
        # Set label smoothing configuration
        smoothing, padding_idx = 0.1, 0
        N, T, H = 128, 74, 32320
        iters = 1000
        loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply
        print()

        logits, labels, half_to_float = self.gen_test_inputs(
            N, T, H, smoothing, padding_idx)

        # Run original softmax cross entropy with label smoothing
        torch.cuda.synchronize()
        ts = time.time()
        for i in range(iters):
            logits.grad = None
            losses = label_smoothing_raw(logits, labels, padding_idx, smoothing)
            loss = losses.sum() / N
            loss.backward()
        torch.cuda.synchronize()
        print("Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
            time.time() - ts, iters, logits.grad.norm()))

        # Run optimized softmax cross entropy with label smoothing
        torch.cuda.synchronize()
        ts = time.time()
        for i in range(iters):
            logits.grad = None
            losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float)
            loss = losses.sum() / N
            loss.backward()
        torch.cuda.synchronize()
        print("Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}".format(
            time.time() - ts, iters, logits.grad.norm()))

if __name__ == '__main__':
    unittest.main()
try:
    import torch
    import xentropy_cuda
    from .softmax_xentropy import SoftmaxCrossEntropyLoss
    del torch
    del xentropy_cuda
    del softmax_xentropy
except ImportError as err:
    print("apex was installed without --xentropy flag, contrib.xentropy is not available")
import torch
import xentropy_cuda

class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        # Fused softmax + cross entropy with label smoothing; returns one loss per
        # row plus the per-row statistics (max_log_sum_exp) reused by backward.
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        # Zero out the loss at padding positions.
        losses.masked_fill_(labels==padding_idx, 0)

        # smoothing and padding_idx are Python scalars; wrap them as tensors so
        # they can be stashed with save_for_backward.
        ctx.save_for_backward(logits, max_log_sum_exp, labels,
            torch.FloatTensor([smoothing]),
            torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        # Padding positions contribute no gradient.
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
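To drop the autograd Function above into existing training code, one might wrap it in a small nn.Module. The wrapper below is a hypothetical sketch, not part of this commit; the name LabelSmoothingCriterion and the choice of a plain sum reduction are assumptions.

import torch

class LabelSmoothingCriterion(torch.nn.Module):
    # Hypothetical convenience wrapper around SoftmaxCrossEntropyLoss above.
    def __init__(self, smoothing=0.1, padding_idx=0):
        super().__init__()
        self.smoothing = smoothing
        self.padding_idx = padding_idx

    def forward(self, logits, labels):
        # Mirror the unit test: request float32 losses when logits are half.
        half_to_float = (logits.dtype == torch.half)
        losses = SoftmaxCrossEntropyLoss.apply(
            logits, labels, self.smoothing, self.padding_idx, half_to_float)
        # Per-token losses with padding rows already zeroed; callers may prefer
        # to normalize by the number of non-padding tokens instead of summing.
        return losses.sum()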
@@ -107,7 +107,7 @@ if "--bnp" in sys.argv:
     cmdclass['build_ext'] = BuildExtension

     if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
+        raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
         # Set up macros for forward/backward compatibility hack around
         # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
@@ -128,6 +128,29 @@ if "--bnp" in sys.argv:
                                               '-gencode',
                                               'arch=compute_70,code=sm_70'] + version_ge_1_1}))

+if "--xentropy" in sys.argv:
+    from torch.utils.cpp_extension import CUDAExtension
+    sys.argv.remove("--xentropy")
+
+    from torch.utils.cpp_extension import BuildExtension
+    cmdclass['build_ext'] = BuildExtension
+
+    if torch.utils.cpp_extension.CUDA_HOME is None:
+        raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
+    else:
+        # Set up macros for forward/backward compatibility hack around
+        # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
+        version_ge_1_1 = []
+        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
+            version_ge_1_1 = ['-DVERSION_GE_1_1']
+        ext_modules.append(
+            CUDAExtension(name='xentropy_cuda',
+                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                                   'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
+                          include_dirs=['csrc'],
+                          extra_compile_args={'cxx': ['-O3'] + version_ge_1_1,
+                                              'nvcc':['-O3'] + version_ge_1_1}))
+
 setup(
     name='apex',
     version='0.1',