Unverified Commit d78d273a authored by Adrià Arrufat, committed by GitHub
Browse files

Add loss multiclass log per pixel weighted cuda (#2194)



* add cuda implementation for loss_multiclass_log_per_pixel_weighted

* add test for cuda and cpu implementations

* fix comment

* move weighted label to its own file

* Update path in doc
Co-authored-by: Davis E. King <davis685@gmail.com>
parent 4125a7bb
......@@ -8,6 +8,7 @@
#include "tensor.h"
#include "../geometry/rectangle.h"
#include "../dnn/misc.h"
namespace dlib
{
......@@ -521,6 +522,82 @@ namespace dlib
// -----------------------------------------------------------------------------------
class compute_loss_multiclass_log_per_pixel_weighted
{
    /*! The point of this class is to compute the loss for loss_multiclass_log_per_pixel_weighted_
        on the cpu to provide an analogous implementation of the cuda version
    !*/
public:

    compute_loss_multiclass_log_per_pixel_weighted(
    )
    {
    }

    // Computes the weighted per-pixel multiclass log loss and writes the
    // corresponding gradient into grad.  truth iterates over one
    // matrix<weighted_label<uint16_t>> per sample; each pixel carries a class
    // label and a weight scaling its contribution to both loss and gradient.
    template <
        typename const_label_iterator
        >
    void operator()(
        const_label_iterator truth,
        const tensor& output_tensor,
        tensor& grad,
        double& loss
    ) const
    {
        // grad starts out holding the softmax of the network output.
        softmax(grad, output_tensor);

        // The loss we output is the weighted average loss over the mini-batch, and also over each element of the matrix output.
        const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc());
        loss = 0;
        float* const g = grad.host();
        for (long sample = 0; sample < output_tensor.num_samples(); ++sample, ++truth)
        {
            for (long row = 0; row < output_tensor.nr(); ++row)
            {
                for (long col = 0; col < output_tensor.nc(); ++col)
                {
                    const auto& wl = truth->operator()(row, col);
                    const uint16_t label = wl.label;
                    const float weight = wl.weight;
                    // The network must produce a number of outputs that is equal to the number
                    // of labels when using this type of loss.
                    DLIB_CASSERT(static_cast<long>(label) < output_tensor.k() || weight == 0.f,
                        "y: " << label << ", output_tensor.k(): " << output_tensor.k());
                    for (long channel = 0; channel < output_tensor.k(); ++channel)
                    {
                        const size_t idx = tensor_index(output_tensor, sample, channel, row, col);
                        const float p = g[idx];
                        if (channel == label)
                        {
                            // true class: accumulate the weighted negative log-likelihood
                            loss += weight*scale*-safe_log(p);
                            g[idx] = weight*scale*(p - 1);
                        }
                        else
                        {
                            g[idx] = weight*scale*p;
                        }
                    }
                }
            }
        }
    }

private:

    // log() clamped away from zero so the loss never becomes infinite.
    template <typename T>
    T safe_log(T input, T epsilon = 1e-10) const
    {
        // Prevent trying to calculate the logarithm of a very small number (let alone zero)
        return std::log(std::max(input, epsilon));
    }

    // Flat index of element (sample, k, row, column) in a dlib tensor.
    static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
    {
        // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
        return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
    }
};
// -----------------------------------------------------------------------------------
class compute_loss_mean_squared_per_channel_and_pixel
{
/*! The point of this class is to compute the loss for loss_mean_squared_per_channel_and_pixel_
......
......@@ -1833,6 +1833,30 @@ namespace dlib
warp_reduce_atomic_add(*loss_out, loss);
}
// Weighted per-pixel multiclass log loss kernel.  On entry g holds the
// softmax of the network output (the caller runs softmax() before launching);
// on exit g holds the scaled gradient and *loss_out accumulates the loss.
//   n           - total element count of g (num_samples*k*nr*nc)
//   plane_size  - nr*nc, elements in one channel plane
//   sample_size - k*nr*nc, elements in one sample
//   nk          - number of channels/classes k
//   truth/weights - one label/weight per pixel, sample-major with no channel
//                 dimension (hence the idx computation below)
//   scale       - 1/(num_samples*nr*nc); applied to the gradient here, while
//                 the loss is accumulated UNSCALED and scaled by the caller.
__global__ void _cuda_compute_loss_multiclass_log_per_pixel_weighted(float* loss_out, float* g, const uint16_t* truth, size_t n, size_t plane_size, size_t sample_size, size_t nk, const float* weights, const float scale)
{
    float loss = 0;
    // grid-stride loop over every element of g
    for(auto i : grid_stride_range(0, n))
    {
        // channel index of element i
        const size_t k = (i/plane_size)%nk;
        // index of the matching pixel in truth/weights: offset within the
        // plane plus plane_size per preceding sample.
        const size_t idx = (i%plane_size) + plane_size*(i/sample_size);
        const size_t y = truth[idx];
        const float weight = weights[idx];
        if (k == y)
        {
            // element of the true class: weighted negative log-likelihood.
            // NOTE(review): a label y >= nk never matches any k, so such pixels
            // silently add no loss here (the cpu version asserts instead).
            loss -= weight*cuda_safe_log(g[i]);
            g[i] = weight*scale*(g[i] - 1);
        }
        else
        {
            g[i] = weight*scale*g[i];
        }
    }
    // fold the per-thread partial losses into the single output scalar
    warp_reduce_atomic_add(*loss_out, loss);
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_compute_loss_mean_squared_per_channel_and_pixel(float* loss_out, float* g, const float* truth, const float* out_data, size_t n, const float scale)
......@@ -1897,6 +1921,30 @@ namespace dlib
loss = scale*floss;
}
// Launches the weighted per-pixel loss kernel and returns the averaged loss.
// loss_work_buffer is a single device float used as the kernel's atomic
// accumulator; truth_buffer/weights_buffer hold one label/weight per pixel,
// laid out sample-major (filled by operator() in the header).
void compute_loss_multiclass_log_per_pixel_weighted::
do_work(
    cuda_data_ptr<float> loss_work_buffer,
    cuda_data_ptr<const uint16_t> truth_buffer,
    cuda_data_ptr<const float> weights_buffer,
    const tensor& subnetwork_output,
    tensor& gradient,
    double& loss
)
{
    // zero the single-float accumulator the kernel atomically adds into
    CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));
    // the kernel expects gradient to hold softmax(subnetwork_output) on entry
    softmax(gradient, subnetwork_output);

    // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
    const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.nr() * subnetwork_output.nc());
    launch_kernel(_cuda_compute_loss_multiclass_log_per_pixel_weighted, max_jobs(gradient.size()),
        loss_work_buffer.data(), gradient.device(), truth_buffer.data(), gradient.size(), gradient.nr()*gradient.nc(), gradient.nr()*gradient.nc()*gradient.k(), gradient.k(), weights_buffer.data(), scale);

    // copy the accumulated (unscaled) loss back to the host
    float floss;
    dlib::cuda::memcpy(&floss, loss_work_buffer);
    // the kernel accumulates the raw weighted log loss; apply the averaging scale here
    loss = scale*floss;
}
void compute_loss_mean_squared_per_channel_and_pixel::
do_work(
cuda_data_ptr<float> loss_work_buffer,
......
......@@ -6,6 +6,7 @@
#include "tensor.h"
#include "../geometry/rectangle.h"
#include "../dnn/misc.h"
namespace dlib
{
......@@ -559,6 +560,82 @@ namespace dlib
mutable cuda_data_void_ptr buf;
};
// ----------------------------------------------------------------------------------------
class compute_loss_multiclass_log_per_pixel_weighted
{
    /*!
        The point of this class is to compute the loss computed by
        loss_multiclass_log_per_pixel_weighted_, but to do so with CUDA.
    !*/
public:

    compute_loss_multiclass_log_per_pixel_weighted(
    )
    {
    }

    // Stages the per-pixel labels and weights into a single device buffer and
    // dispatches to do_work(), which launches the CUDA kernel.  truth iterates
    // over one matrix<weighted_label<uint16_t>> per sample.  The device buffer
    // layout is:
    //   [1 float loss accumulator][num_samples uint16 label planes][num_samples float weight planes]
    template <
        typename const_label_iterator
    >
    void operator() (
        const_label_iterator truth,
        const tensor& subnetwork_output,
        tensor& gradient,
        double& loss
    ) const
    {
        const auto image_size = subnetwork_output.nr()*subnetwork_output.nc();
        const size_t bytes_per_plane = image_size*sizeof(uint16_t);
        const size_t weight_bytes_per_plane = image_size*sizeof(float);
        // host staging matrices reused for every sample
        matrix<uint16_t> labels(truth->nr(), truth->nc());
        matrix<float> weights(truth->nr(), truth->nc());
        // Allocate a cuda buffer to store all the truth images and also one float
        // for the scalar loss output.
        buf = device_global_buffer(subnetwork_output.num_samples()*(bytes_per_plane + weight_bytes_per_plane) + sizeof(float));
        // the first float of the buffer is the loss accumulator
        cuda_data_ptr<float> loss_buf = static_pointer_cast<float>(buf, 1);
        // advance buf past the loss float; it now points at the label planes.
        // NOTE: buf is reassigned from device_global_buffer() on every call,
        // so mutating it here is safe.
        buf = buf+sizeof(float);
        // byte offset from the label planes to the weight planes
        const auto weights_offset = subnetwork_output.num_samples() * bytes_per_plane;
        // copy the truth data into a cuda buffer.
        for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth)
        {
            const matrix<weighted_label<uint16_t>>& t = *truth;
            DLIB_ASSERT(t.nr() == subnetwork_output.nr());
            DLIB_ASSERT(t.nc() == subnetwork_output.nc());
            // split each weighted_label plane into separate label/weight planes
            for (long r = 0; r < t.nr(); ++r)
            {
                for (long c = 0; c < t.nc(); ++c)
                {
                    labels(r, c) = t(r, c).label;
                    weights(r, c) = t(r, c).weight;
                }
            }
            // host->device copies of this sample's label and weight planes
            memcpy(buf + i*bytes_per_plane, &labels(0,0), bytes_per_plane);
            memcpy(buf + weights_offset + i*weight_bytes_per_plane, &weights(0, 0), weight_bytes_per_plane);
        }
        auto truth_buf = static_pointer_cast<const uint16_t>(buf, subnetwork_output.num_samples()*image_size);
        // advance buf to the weight planes
        buf = buf+weights_offset;
        auto weights_buf = static_pointer_cast<const float>(buf, subnetwork_output.num_samples()*image_size);
        do_work(loss_buf, truth_buf, weights_buf, subnetwork_output, gradient, loss);
    }

private:

    // Defined in the .cu file: zeroes the accumulator, runs softmax, launches
    // the kernel, and copies the scaled loss back to the host.
    static void do_work(
        cuda_data_ptr<float> loss_work_buffer,
        cuda_data_ptr<const uint16_t> truth_buffer,
        cuda_data_ptr<const float> weights_buffer,
        const tensor& subnetwork_output,
        tensor& gradient,
        double& loss
    );

    // scratch device buffer; reacquired from device_global_buffer() each call
    mutable cuda_data_void_ptr buf;
};
// ----------------------------------------------------------------------------------------
class compute_loss_mean_squared_per_channel_and_pixel
......
......@@ -6,6 +6,7 @@
#include "loss_abstract.h"
#include "core.h"
#include "utilities.h"
#include "misc.h"
#include "../matrix.h"
#include "../cuda/tensor_tools.h"
#include "../geometry.h"
......@@ -368,22 +369,6 @@ namespace dlib
template <typename SUBNET>
using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
// ----------------------------------------------------------------------------------------
// A truth label paired with a weight that scales the sample's contribution to
// the loss and gradient (weight defaults to 1, i.e. normal emphasis).  Used by
// loss_multiclass_log_weighted_ and loss_multiclass_log_per_pixel_weighted_.
template <typename label_type>
struct weighted_label
{
    weighted_label()
    {}

    weighted_label(label_type label, float weight = 1.f)
        : label(label), weight(weight)
    {}

    // the truth label for this sample
    label_type label{};
    // relative emphasis given to this sample during training
    float weight = 1.f;
};
// ----------------------------------------------------------------------------------------
class loss_multiclass_log_weighted_
......@@ -3139,41 +3124,12 @@ namespace dlib
"output size = " << output_tensor.nr() << " x " << output_tensor.nc());
}
tt::softmax(grad, output_tensor);
// The loss we output is the weighted average loss over the mini-batch, and also over each element of the matrix output.
const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc());
double loss = 0;
float* const g = grad.host();
for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
{
for (long r = 0; r < output_tensor.nr(); ++r)
{
for (long c = 0; c < output_tensor.nc(); ++c)
{
const weighted_label& weighted_label = truth->operator()(r, c);
const uint16_t y = weighted_label.label;
const float weight = weighted_label.weight;
// The network must produce a number of outputs that is equal to the number
// of labels when using this type of loss.
DLIB_CASSERT(static_cast<long>(y) < output_tensor.k() || weight == 0.f,
"y: " << y << ", output_tensor.k(): " << output_tensor.k());
for (long k = 0; k < output_tensor.k(); ++k)
{
const size_t idx = tensor_index(output_tensor, i, k, r, c);
if (k == y)
{
loss += weight*scale*-safe_log(g[idx]);
g[idx] = weight*scale*(g[idx] - 1);
}
else
{
g[idx] = weight*scale*g[idx];
}
}
}
}
}
double loss;
#ifdef DLIB_USE_CUDA
cuda_compute(truth, output_tensor, grad, loss);
#else
cpu_compute(truth, output_tensor, grad, loss);
#endif
return loss;
}
......@@ -3207,6 +3163,11 @@ namespace dlib
// See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
}
#ifdef DLIB_USE_CUDA
cuda::compute_loss_multiclass_log_per_pixel_weighted cuda_compute;
#else
cpu::compute_loss_multiclass_log_per_pixel_weighted cpu_compute;
#endif
};
......
......@@ -379,11 +379,12 @@ namespace dlib
This object represents the truth label of a single sample, together with
an associated weight (the higher the weight, the more emphasis the
corresponding sample is given during the training).
For technical reasons, it is defined in misc.h
This object is used in the following loss layers:
- loss_multiclass_log_weighted_ with unsigned long as label_type
- loss_multiclass_log_per_pixel_weighted_ with uint16_t as label_type,
since, in semantic segmentation, 65536 classes ought to be enough for
anybody.
!*/
weighted_label()
{}
......
// Copyright (C) 2020 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_MISC_h
#define DLIB_DNn_MISC_h
namespace dlib
{
// ----------------------------------------------------------------------------------------
// Pairs a truth label with a relative weight controlling how much emphasis the
// corresponding sample receives during training.  Used by the weighted loss
// layers (e.g. loss_multiclass_log_per_pixel_weighted_ with uint16_t labels).
template <typename label_type>
struct weighted_label
{
    // the truth label for this sample
    label_type label{};
    // relative emphasis of this sample; 1 means normal weighting
    float weight = 1.f;

    weighted_label()
    {
    }

    weighted_label(label_type label, float weight = 1.f)
        : label(label),
          weight(weight)
    {
    }
};
}
#endif // DLIB_DNn_MISC_h
......@@ -3311,6 +3311,18 @@ namespace
DLIB_TEST_MSG(num_weighted_class > num_not_weighted_class,
"The weighted class (" << weighted_class << ") does not dominate: "
<< num_weighted_class << " <= " << num_not_weighted_class);
// Use #ifdef, not #if: dlib defines DLIB_USE_CUDA with no value, and the rest
// of this code base tests it with #ifdef, so "#if DLIB_USE_CUDA" is malformed.
#ifdef DLIB_USE_CUDA
    // Verify the CUDA and CPU implementations of the weighted per-pixel loss
    // agree on the same network output (relative error, since the losses are
    // float accumulations performed in different orders).
    cuda::compute_loss_multiclass_log_per_pixel_weighted cuda_compute;
    cpu::compute_loss_multiclass_log_per_pixel_weighted cpu_compute;
    double cuda_loss, cpu_loss;
    const tensor& output_tensor = net.subnet().get_output();
    tensor& grad = net.subnet().get_gradient_input();
    // each call runs softmax on grad first, so the second call is unaffected
    // by the gradient the first one wrote
    cuda_compute(y_weighted.begin(), output_tensor, grad, cuda_loss);
    cpu_compute(y_weighted.begin(), output_tensor, grad, cpu_loss);
    const auto err = abs(cuda_loss - cpu_loss) / cpu_loss;
    DLIB_TEST_MSG(err < 1e-6, "multi class log per pixel weighted cuda and cpu losses differ");
#endif
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment