Unverified Commit 3d4e08e1 authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[CUDA] Add multiclass objective for cuda_exp (#5473)

* add multiclass objective for cuda_exp

* remove debug code

* add includes requested by lint checks

* fix compilation failure for cuda with cuda-9.0

* clean code
parent 2e9848c6
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifdef USE_CUDA_EXP
#include "cuda_multiclass_objective.hpp"
#include <string>
#include <vector>
namespace LightGBM {
CUDAMulticlassSoftmax::CUDAMulticlassSoftmax(const Config& config): MulticlassSoftmax(config) {}
CUDAMulticlassSoftmax::CUDAMulticlassSoftmax(const std::vector<std::string>& strs): MulticlassSoftmax(strs) {}
CUDAMulticlassSoftmax::~CUDAMulticlassSoftmax() {}
void CUDAMulticlassSoftmax::Init(const Metadata& metadata, data_size_t num_data) {
MulticlassSoftmax::Init(metadata, num_data);
cuda_label_ = metadata.cuda_metadata()->cuda_label();
cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
cuda_softmax_buffer_.Resize(static_cast<size_t>(num_data) * static_cast<size_t>(num_class_));
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDAMulticlassSoftmax::GetGradients(const double* score, score_t* gradients, score_t* hessians) const {
LaunchGetGradientsKernel(score, gradients, hessians);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
void CUDAMulticlassSoftmax::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
LaunchConvertOutputCUDAKernel(num_data, input, output);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifdef USE_CUDA_EXP
#include <algorithm>
#include "cuda_multiclass_objective.hpp"
namespace LightGBM {
__device__ void SoftmaxCUDA(double* softmax_buffer, int len) {
double wmax = softmax_buffer[0];
for (int i = 1; i < len; ++i) {
wmax = max(softmax_buffer[i], wmax);
}
double wsum = 0.0f;
for (int i = 0; i < len; ++i) {
softmax_buffer[i] = exp(softmax_buffer[i] - wmax);
wsum += softmax_buffer[i];
}
for (int i = 0; i < len; ++i) {
softmax_buffer[i] /= static_cast<double>(wsum);
}
}
template <bool USE_WEIGHT>
__global__ void GetGradientsKernel_MulticlassSoftmax(
const double* cuda_scores, const label_t* cuda_labels, const label_t* cuda_weights,
const double factor, const int num_class, const data_size_t num_data,
double* cuda_softmax_buffer, score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
const data_size_t offset = data_index * num_class;
double* softmax_result = cuda_softmax_buffer + offset;
for (int k = 0; k < num_class; ++k) {
const double point_score = cuda_scores[k * num_data + data_index];
softmax_result[k] = cuda_scores[k * num_data + data_index];
}
SoftmaxCUDA(softmax_result, num_class);
if (!USE_WEIGHT) {
for (int k = 0; k < num_class; ++k) {
const double p = softmax_result[k];
size_t idx = static_cast<size_t>(num_data) * k + data_index;
if (static_cast<int>(cuda_labels[data_index]) == k) {
cuda_out_gradients[idx] = static_cast<score_t>(p - 1.0f);
} else {
cuda_out_gradients[idx] = static_cast<score_t>(p);
}
cuda_out_hessians[idx] = static_cast<score_t>(factor * p * (1.0f - p));
}
} else {
for (int k = 0; k < num_class; ++k) {
const double p = softmax_result[k];
const double weight = cuda_weights[data_index];
size_t idx = static_cast<size_t>(num_data) * k + data_index;
if (static_cast<int>(cuda_labels[data_index]) == k) {
cuda_out_gradients[idx] = static_cast<score_t>((p - 1.0f) * weight);
} else {
cuda_out_gradients[idx] = static_cast<score_t>(p * weight);
}
cuda_out_hessians[idx] = static_cast<score_t>((factor * p * (1.0f - p)) * weight);
}
}
}
}
void CUDAMulticlassSoftmax::LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const {
const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_MULTICLASS - 1) / GET_GRADIENTS_BLOCK_SIZE_MULTICLASS;
if (cuda_weights_ == nullptr) {
GetGradientsKernel_MulticlassSoftmax<false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
scores, cuda_label_, cuda_weights_, factor_, num_class_, num_data_,
cuda_softmax_buffer_.RawData(), gradients, hessians);
} else {
GetGradientsKernel_MulticlassSoftmax<true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
scores, cuda_label_, cuda_weights_, factor_, num_class_, num_data_,
cuda_softmax_buffer_.RawData(), gradients, hessians);
}
}
__global__ void ConvertOutputCUDAKernel_MulticlassSoftmax(
const int num_class, const data_size_t num_data, const double* input, double* cuda_softmax_buffer, double* output) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
const data_size_t offset = data_index * num_class;
double* cuda_softmax_buffer_ptr = cuda_softmax_buffer + offset;
for (int class_index = 0; class_index < num_class; ++class_index) {
cuda_softmax_buffer_ptr[class_index] = input[class_index * num_data + data_index];
}
SoftmaxCUDA(cuda_softmax_buffer_ptr, num_class);
for (int class_index = 0; class_index < num_class; ++class_index) {
output[class_index * num_data + data_index] = cuda_softmax_buffer_ptr[class_index];
}
}
}
void CUDAMulticlassSoftmax::LaunchConvertOutputCUDAKernel(
const data_size_t num_data, const double* input, double* output) const {
const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_MULTICLASS - 1) / GET_GRADIENTS_BLOCK_SIZE_MULTICLASS;
ConvertOutputCUDAKernel_MulticlassSoftmax<<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_MULTICLASS>>>(
num_class_, num_data, input, cuda_softmax_buffer_.RawData(), output);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_
#ifdef USE_CUDA_EXP
#include <LightGBM/cuda/cuda_objective_function.hpp>
#include <string>
#include <vector>
#include "../multiclass_objective.hpp"
#define GET_GRADIENTS_BLOCK_SIZE_MULTICLASS (1024)
namespace LightGBM {
class CUDAMulticlassSoftmax: public CUDAObjectiveInterface, public MulticlassSoftmax {
public:
explicit CUDAMulticlassSoftmax(const Config& config);
explicit CUDAMulticlassSoftmax(const std::vector<std::string>& strs);
~CUDAMulticlassSoftmax();
void Init(const Metadata& metadata, data_size_t num_data) override;
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override;
void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;
std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
return [this] (data_size_t num_data, const double* input, double* output) {
ConvertOutputCUDA(num_data, input, output);
};
}
bool IsCUDAObjective() const override { return true; }
private:
void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const;
void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;
// CUDA memory, held by other objects
const label_t* cuda_label_;
const label_t* cuda_weights_;
// CUDA memory, held by this object
CUDAVector<double> cuda_softmax_buffer_;
};
} // namespace LightGBM
#endif // USE_CUDA_EXP
#endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_MULTICLASS_OBJECTIVE_HPP_
...@@ -165,7 +165,7 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -165,7 +165,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} }
} }
private: protected:
double factor_; double factor_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "xentropy_objective.hpp" #include "xentropy_objective.hpp"
#include "cuda/cuda_binary_objective.hpp" #include "cuda/cuda_binary_objective.hpp"
#include "cuda/cuda_multiclass_objective.hpp"
#include "cuda/cuda_rank_objective.hpp" #include "cuda/cuda_rank_objective.hpp"
#include "cuda/cuda_regression_objective.hpp" #include "cuda/cuda_regression_objective.hpp"
...@@ -40,8 +41,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -40,8 +41,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} else if (type == std::string("rank_xendcg")) { } else if (type == std::string("rank_xendcg")) {
return new CUDARankXENDCG(config); return new CUDARankXENDCG(config);
} else if (type == std::string("multiclass")) { } else if (type == std::string("multiclass")) {
Log::Warning("Objective multiclass is not implemented in cuda_exp version. Fall back to boosting on CPU."); return new CUDAMulticlassSoftmax(config);
return new MulticlassSoftmax(config);
} else if (type == std::string("multiclassova")) { } else if (type == std::string("multiclassova")) {
Log::Warning("Objective multiclassova is not implemented in cuda_exp version. Fall back to boosting on CPU."); Log::Warning("Objective multiclassova is not implemented in cuda_exp version. Fall back to boosting on CPU.");
return new MulticlassOVA(config); return new MulticlassOVA(config);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment