Unverified Commit 28f8539c authored by Masaki Kozuki, committed by GitHub

Add CUDA Focal Loss Implementation (#1337)



Take-over of #1097

* Add fast CUDA focal loss implementation.

* Enable fast math for CUDA focal loss.

* Correct typo.

* replace deprecated macros

* TORCH_CUDA_CHECK -> AT_CUDA_CHECK

The former is defined in torch/csrc/profiler/cuda.cpp, so it is not usually available.
The latter, however, is defined in ATen/cuda/Exceptions.h as an alias of C10_CUDA_CHECK.

* add test

* clean up

* guard for torchvision
Co-authored-by: Wil Kong <alpha0422@gmail.com>
parent feae3851
#include <torch/torch.h>
#include <vector>
#include <cstdint>
// CUDA forward declarations
std::vector<at::Tensor> focal_loss_forward_cuda(
const at::Tensor &cls_output,
const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum,
const int64_t num_real_classes,
const float alpha,
const float gamma,
const float smoothing_factor);
at::Tensor focal_loss_backward_cuda(
const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> focal_loss_forward(
const at::Tensor &cls_output,
const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum,
const int64_t num_real_classes,
const float alpha,
const float gamma,
const float smoothing_factor
) {
CHECK_INPUT(cls_output);
CHECK_INPUT(cls_targets_at_level);
CHECK_INPUT(num_positives_sum);
return focal_loss_forward_cuda(
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
smoothing_factor);
}
at::Tensor focal_loss_backward(
const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum
) {
CHECK_INPUT(grad_output);
CHECK_INPUT(partial_grad);
return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &focal_loss_forward,
"Focal loss calculation forward (CUDA)");
m.def("backward", &focal_loss_backward,
"Focal loss calculation backward (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#define ASSERT_UINT4_ALIGNED(PTR) \
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
template <class T> bool is_aligned(const void *ptr) noexcept {
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
return !(iptr % alignof(T));
}
template <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t,
typename accscalar_t, typename outscalar_t>
__global__ void focal_loss_forward_cuda_kernel(
outscalar_t *loss, scalar_t *partial_grad,
const scalar_t *__restrict__ cls_output,
const labelscalar_t *__restrict__ cls_targets_at_level,
const float *__restrict__ num_positives_sum, const int64_t num_examples,
const int64_t num_classes, const int64_t num_real_classes,
const float alpha, const float gamma, const float smoothing_factor) {
extern __shared__ unsigned char shm[];
accscalar_t *loss_shm = reinterpret_cast<accscalar_t *>(shm);
loss_shm[threadIdx.x] = 0;
accscalar_t loss_acc = 0;
accscalar_t one = accscalar_t(1.0);
accscalar_t K = accscalar_t(2.0);
accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]);
accscalar_t nn_norm, np_norm, pn_norm, pp_norm;
// *_norm is used for label smoothing only
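// With smoothing factor s and K = 2, the smoothed targets are s/K for a
// negative match (np_norm) and 1 - s + s/K for a positive match (pp_norm);
// nn_norm and pn_norm are the corresponding (1 - target) factors that
// multiply the logit in the stable BCE expression computed below.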
if (SMOOTHING) {
nn_norm = one - smoothing_factor / K;
np_norm = smoothing_factor / K;
pn_norm = smoothing_factor - smoothing_factor / K;
pp_norm = one - smoothing_factor + smoothing_factor / K;
}
uint4 p_vec, grad_vec;
// Accumulate loss on each thread
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) {
int64_t idy = i / num_classes;
labelscalar_t y = cls_targets_at_level[idy];
int64_t base_yid = i % num_classes;
int64_t pos_idx = idy * num_classes + y;
p_vec = *(uint4 *)&cls_output[i];
// Skip ignored matches
if (y == -2) {
#pragma unroll
for (int j = 0; j < ILP; j++) {
*((scalar_t *)(&grad_vec) + j) = 0;
}
*(uint4 *)&partial_grad[i] = grad_vec;
continue;
}
#pragma unroll
for (int j = 0; j < ILP; j++) {
// Skip the padded classes
if (base_yid + j >= num_real_classes) {
*((scalar_t *)(&grad_vec) + j) = 0;
continue;
}
accscalar_t p = static_cast<accscalar_t>(*((scalar_t *)(&p_vec) + j));
accscalar_t exp_np = ::exp(-p);
accscalar_t exp_pp = ::exp(p);
accscalar_t sigma = one / (one + exp_np);
accscalar_t logee = (p >= 0) ? exp_np : exp_pp;
accscalar_t addee = (p >= 0) ? 0 : -p;
accscalar_t off_a = addee + ::log(one + logee);
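// off_a == log(1 + exp(-p)) == -log(sigmoid(p)), evaluated in a numerically
// stable way for either sign of p.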
// Negative matches
accscalar_t base = SMOOTHING ? nn_norm * p : p;
accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma;
accscalar_t coeff_f1 = one - alpha;
accscalar_t coeff_f2 = sigma;
accscalar_t coeff_b1 = gamma;
accscalar_t coeff_b2 = one - sigma;
// Positive matches
if (y >= 0 && (i + j == pos_idx)) {
base = SMOOTHING ? pn_norm * p : 0;
off_b = (SMOOTHING ? pp_norm : one) - sigma;
coeff_f1 = alpha;
coeff_f2 = one - sigma;
coeff_b1 = -gamma;
coeff_b2 = sigma;
}
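// coeff_f = alpha_t * (1 - p_t)^gamma is the focal modulating factor, and
// (base + off_a) is the (optionally label-smoothed) binary cross entropy,
// which reduces to -log(p_t) without smoothing; loss_t below is then the
// usual focal loss term -alpha_t * (1 - p_t)^gamma * log(p_t).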
accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma);
accscalar_t coeff_b = coeff_b1 * coeff_b2;
accscalar_t loss_t = coeff_f * (base + off_a);
accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b);
// Delay normalizing the partial gradient by num_positives_sum until back
// propagation, because scalar_t reduces precision and focal loss is very
// sensitive to small gradients. There is no overflow concern here since the
// gradient has a smaller range than the input.
loss_acc += loss_t;
*((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad);
}
// This does not guarantee a single stg.128; the compiler may emit two stg.64 stores instead.
*(uint4 *)&partial_grad[i] = grad_vec;
}
loss_shm[threadIdx.x] = loss_acc;
// Intra-CTA reduction
__syncthreads();
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s];
}
__syncthreads();
}
// Inter-CTA reduction
if (threadIdx.x == 0) {
loss_acc = loss_shm[0] * normalizer;
atomicAdd(loss, loss_acc);
}
}
template <int ILP, typename scalar_t, typename accscalar_t,
typename outscalar_t>
__global__ void focal_loss_backward_cuda_kernel(
scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output,
const float *__restrict__ num_positives_sum, const uint64_t numel) {
int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) /
static_cast<accscalar_t>(num_positives_sum[0]);
// The input is required to be padded for vectorized loads, so there is no need
// to check whether the last ILP elements go out of bounds.
if (idx >= numel)
return;
uint4 grad_vec;
grad_vec = *(uint4 *)&partial_grad[idx];
#pragma unroll(ILP)
for (int i = 0; i < ILP; i++) {
auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i));
grad *= normalizer;
*((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad);
}
*(uint4 *)&partial_grad[idx] = grad_vec;
}
std::vector<at::Tensor> focal_loss_forward_cuda(
const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
const at::Tensor &num_positives_sum, const int64_t num_real_classes,
const float alpha, const float gamma, const float smoothing_factor) {
// Checks required for correctness
TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes,
"Incorrect number of real classes.");
TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong,
"Invalid label type.");
TORCH_INTERNAL_ASSERT(
(num_positives_sum.numel() == 1) &&
(num_positives_sum.scalar_type() == at::kFloat),
"Expect num_positives_sum to be a float32 tensor with only one element.");
TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,
"Mis-matched dimensions between class output and label.");
for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)
TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),
"Mis-matched shape between class output and label.");
// Checks required for better performance
const int ILP = sizeof(uint4) / cls_output.element_size();
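// sizeof(uint4) is 16 bytes, so ILP (elements per 128-bit vector access) is
// 8 for half and 4 for float.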
ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,
"Pad the number of classes first to take advantage of 128-bit loads.");
TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");
int64_t num_classes = cls_output.size(-1);
int64_t num_examples = cls_output.numel() / num_classes;
at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));
// Compute the incomplete gradient during fprop, since most of the heavy
// computation of bprop is the same as fprop; trading memory for compute
// helps with focal loss.
at::Tensor partial_grad = at::empty_like(cls_output);
// The grid contains 2 CTAs per SM; each CTA loops over the input with a grid
// stride until the last element.
cudaDeviceProp props;
cudaGetDeviceProperties(&props, at::cuda::current_device());
dim3 block(512);
dim3 grid(2 * props.multiProcessorCount);
// Specialize on label smoothing or not to reduce redundant operations
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (smoothing_factor == 0.0f) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
cls_output.scalar_type(), "focal_loss_fprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using labelscalar_t = int64_t;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
accscalar_t, outscalar_t>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
loss.data_ptr<outscalar_t>(),
partial_grad.data_ptr<scalar_t>(),
cls_output.data_ptr<scalar_t>(),
cls_targets_at_level.data_ptr<labelscalar_t>(),
num_positives_sum.data_ptr<float>(), num_examples,
num_classes, num_real_classes, alpha, gamma,
smoothing_factor);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
cls_output.scalar_type(), "focal_loss_fprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using labelscalar_t = int64_t;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
accscalar_t, outscalar_t>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
loss.data_ptr<outscalar_t>(),
partial_grad.data_ptr<scalar_t>(),
cls_output.data_ptr<scalar_t>(),
cls_targets_at_level.data_ptr<labelscalar_t>(),
num_positives_sum.data_ptr<float>(), num_examples,
num_classes, num_real_classes, alpha, gamma,
smoothing_factor);
});
}
AT_CUDA_CHECK(cudaGetLastError());
return {loss, partial_grad};
}
at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
const at::Tensor &partial_grad,
const at::Tensor &num_positives_sum) {
// Each thread processes ILP elements
const int ILP = sizeof(uint4) / partial_grad.element_size();
dim3 block(512);
dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
partial_grad.scalar_type(), "focal_loss_bprop", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
using outscalar_t = float;
const int ILP = sizeof(uint4) / sizeof(scalar_t);
focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
<<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(),
grad_output.data_ptr<outscalar_t>(),
num_positives_sum.data_ptr<float>(),
partial_grad.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return partial_grad;
}
try:
    import torch
    import focal_loss_cuda
    from .focal_loss import focal_loss
    del torch
    del focal_loss_cuda
    del focal_loss
except ImportError as err:
    print("apex was installed without the --focal_loss flag; apex.contrib.focal_loss is not available")
import torch

import focal_loss_cuda


class FocalLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        cls_output,
        cls_targets_at_level,
        num_positives_sum,
        num_real_classes,
        alpha,
        gamma,
        label_smoothing=0.0,
    ):
        loss, partial_grad = focal_loss_cuda.forward(
            cls_output,
            cls_targets_at_level,
            num_positives_sum,
            num_real_classes,
            alpha,
            gamma,
            label_smoothing,
        )
        ctx.save_for_backward(partial_grad, num_positives_sum)
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        partial_grad, num_positives_sum = ctx.saved_tensors
        # The backward kernel is actually in-place to save memory space;
        # partial_grad and grad_input are the same tensor.
        grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum)
        return grad_input, None, None, None, None, None, None


def focal_loss(
    cls_output: torch.Tensor,
    cls_targets_at_level: torch.Tensor,
    num_positive_sum: torch.Tensor,
    num_real_classes: int,
    alpha: float,
    gamma: float,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Fused focal loss function."""
    return FocalLoss.apply(
        cls_output,
        cls_targets_at_level,
        num_positive_sum,
        num_real_classes,
        alpha,
        gamma,
        label_smoothing,
    )
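A minimal usage sketch of the wrapper above (not part of the commit), assuming the extension was built with the --focal_loss flag handled in setup.py below. As in the test that follows, the imported focal_loss name refers to the submodule. The tensor names and sizes here are illustrative; they follow the checks enforced in focal_loss_forward_cuda: a contiguous CUDA logits tensor whose class dimension is a multiple of the 128-bit vector width, int64 targets, and a single-element float32 num_positives_sum.

import torch

from apex.contrib.focal_loss import focal_loss

# Hypothetical sizes for illustration: 12 anchors, 8 classes.
# 8 is a multiple of ILP = 4 for fp32, so the 128-bit load requirement holds.
cls_output = torch.randn(12, 8, device="cuda", requires_grad=True)  # contiguous CUDA logits
cls_targets = torch.randint(0, 8, (12,), device="cuda")             # int64 labels (entries equal to -2 are skipped by the kernel)
num_positives_sum = torch.ones([], device="cuda")                   # single-element float32 tensor

loss = focal_loss.focal_loss(
    cls_output,
    cls_targets,
    num_positives_sum,
    8,     # num_real_classes
    0.25,  # alpha
    2.0,   # gamma
    0.0,   # label_smoothing
)
loss.backward()  # the fused backward kernel fills cls_output.grad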
import unittest

import torch
import torch.nn.functional as F

reference_available = True
try:
    from torchvision.ops.focal_loss import sigmoid_focal_loss
except ImportError:
    reference_available = False

from apex.contrib.focal_loss import focal_loss


class FocalLossTest(unittest.TestCase):

    N_SAMPLES = 12
    N_CLASSES = 8
    ALPHA = 0.24
    GAMMA = 2.0
    REDUCTION = "sum"

    def test_focal_loss(self) -> None:
        if not reference_available:
            self.skipTest("This test needs `torchvision` for `torchvision.ops.focal_loss.sigmoid_focal_loss`.")
        else:
            x = torch.randn(FocalLossTest.N_SAMPLES, FocalLossTest.N_CLASSES).cuda()
            with torch.no_grad():
                x_expected = x.clone()
                x_actual = x.clone()
            x_expected.requires_grad_()
            x_actual.requires_grad_()

            classes = torch.randint(0, FocalLossTest.N_CLASSES, (FocalLossTest.N_SAMPLES,)).cuda()
            with torch.no_grad():
                y = F.one_hot(classes, FocalLossTest.N_CLASSES).float()

            expected = sigmoid_focal_loss(
                x_expected,
                y,
                alpha=FocalLossTest.ALPHA,
                gamma=FocalLossTest.GAMMA,
                reduction=FocalLossTest.REDUCTION,
            )

            actual = sum([focal_loss.FocalLoss.apply(
                x_actual[i:i+1],
                classes[i:i+1].long(),
                torch.ones([], device="cuda"),
                FocalLossTest.N_CLASSES,
                FocalLossTest.ALPHA,
                FocalLossTest.GAMMA,
                0.0,
            ) for i in range(FocalLossTest.N_SAMPLES)])

            # forward parity
            torch.testing.assert_close(expected, actual)

            expected.backward()
            actual.backward()
            # grad parity
            torch.testing.assert_close(x_expected.grad, x_actual.grad)


if __name__ == "__main__":
    torch.manual_seed(42)
    unittest.main()
@@ -392,6 +392,24 @@ if "--xentropy" in sys.argv:
        )
    )

if "--focal_loss" in sys.argv:
    sys.argv.remove("--focal_loss")
    raise_if_cuda_home_none("--focal_loss")
    ext_modules.append(
        CUDAExtension(
            name='focal_loss_cuda',
            sources=[
                'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp',
                'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu',
            ],
            include_dirs=[os.path.join(this_dir, 'csrc')],
            extra_compile_args={
                'cxx': ['-O3'] + version_dependent_macros,
                'nvcc': ['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,
            },
        )
    )

if "--deprecated_fused_adam" in sys.argv:
    sys.argv.remove("--deprecated_fused_adam")
    raise_if_cuda_home_none("--deprecated_fused_adam")
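For quick experimentation without reinstalling apex, the same sources can also be JIT-compiled with torch.utils.cpp_extension.load. This is only a sketch: the source paths, include directory, and nvcc flags mirror the setup.py hunk above, while apex_root (the path to a local apex checkout) is an assumption, and the version-dependent macros passed by setup.py are omitted.

import os

from torch.utils.cpp_extension import load

# Assumed location of a local apex checkout; adjust as needed.
apex_root = os.path.expanduser("~/apex")

focal_loss_cuda = load(
    name="focal_loss_cuda",
    sources=[
        os.path.join(apex_root, "apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp"),
        os.path.join(apex_root, "apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu"),
    ],
    extra_include_paths=[os.path.join(apex_root, "csrc")],  # mirrors include_dirs above
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3", "--use_fast_math", "--ftz=false"],
)

# The loaded module exposes the pybind11 bindings defined earlier, e.g.:
# loss, partial_grad = focal_loss_cuda.forward(cls_output, targets, num_pos, num_real_classes, alpha, gamma, 0.0)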