Unverified Commit 2b8fe8b4 authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[CUDA] Add binary objective for cuda_exp (#5425)

* add binary objective for cuda_exp

* include <string> and <vector>

* exchange include ordering

* fix length of score to copy in evaluation

* fix EvalOneMetric

* fix cuda binary objective and prediction when boosting on gpu

* Add white space

* fix BoostFromScore for CUDABinaryLogloss

update log in test_register_logger

* include <algorithm>

* simplify shared memory buffer
parent 81d4d4d1
...@@ -398,6 +398,8 @@ endif() ...@@ -398,6 +398,8 @@ endif()
if(USE_CUDA_EXP) if(USE_CUDA_EXP)
src/boosting/cuda/*.cpp src/boosting/cuda/*.cpp
src/boosting/cuda/*.cu src/boosting/cuda/*.cu
src/objective/cuda/*.cpp
src/objective/cuda/*.cu
src/treelearner/cuda/*.cpp src/treelearner/cuda/*.cpp
src/treelearner/cuda/*.cu src/treelearner/cuda/*.cu
src/io/cuda/*.cu src/io/cuda/*.cu
......
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/meta.h>

namespace LightGBM {

/*!
 * \brief Mix-in interface for objective functions that can convert raw scores
 *        into final outputs directly in CUDA device memory.
 */
class CUDAObjectiveInterface {
 public:
  // Fix: this class is a polymorphic base (CUDA objectives derive from it and
  // it has a virtual member function), so it needs a virtual destructor to
  // make deletion through a base pointer well-defined.
  virtual ~CUDAObjectiveInterface() {}

  /*!
   * \brief Convert raw scores to objective outputs on the CUDA device.
   * \param num_data Number of scores to convert
   * \param input Raw scores (device memory)
   * \param output Converted outputs (device memory)
   * The default implementation is a no-op; objectives whose output equals the
   * raw score need not override it.
   */
  virtual void ConvertOutputCUDA(const data_size_t /*num_data*/, const double* /*input*/, double* /*output*/) const {}
};

}  // namespace LightGBM

#endif  // USE_CUDA_EXP
#endif  // LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_
...@@ -93,6 +93,15 @@ class ObjectiveFunction { ...@@ -93,6 +93,15 @@ class ObjectiveFunction {
* \brief Whether boosting is done on CUDA * \brief Whether boosting is done on CUDA
*/ */
virtual bool IsCUDAObjective() const { return false; } virtual bool IsCUDAObjective() const { return false; }
#ifdef USE_CUDA_EXP
/*!
* \brief Get output convert function for CUDA version
*/
virtual std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const {
return [] (data_size_t, const double*, double*) {};
}
#endif // USE_CUDA_EXP
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -607,7 +607,11 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) { ...@@ -607,7 +607,11 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
} }
} }
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const { #ifdef USE_CUDA_EXP
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score, const data_size_t num_data) const {
#else
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score, const data_size_t /*num_data*/) const {
#endif // USE_CUDA_EXP
#ifdef USE_CUDA_EXP #ifdef USE_CUDA_EXP
const bool evaluation_on_cuda = metric->IsCUDAMetric(); const bool evaluation_on_cuda = metric->IsCUDAMetric();
if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) { if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) {
...@@ -615,14 +619,14 @@ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* scor ...@@ -615,14 +619,14 @@ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* scor
return metric->Eval(score, objective_function_); return metric->Eval(score, objective_function_);
#ifdef USE_CUDA_EXP #ifdef USE_CUDA_EXP
} else if (boosting_on_gpu_ && !evaluation_on_cuda) { } else if (boosting_on_gpu_ && !evaluation_on_cuda) {
const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_); const size_t total_size = static_cast<size_t>(num_data) * static_cast<size_t>(num_tree_per_iteration_);
if (total_size > host_score_.size()) { if (total_size > host_score_.size()) {
host_score_.resize(total_size, 0.0f); host_score_.resize(total_size, 0.0f);
} }
CopyFromCUDADeviceToHost<double>(host_score_.data(), score, total_size, __FILE__, __LINE__); CopyFromCUDADeviceToHost<double>(host_score_.data(), score, total_size, __FILE__, __LINE__);
return metric->Eval(host_score_.data(), objective_function_); return metric->Eval(host_score_.data(), objective_function_);
} else { } else {
const size_t total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_tree_per_iteration_); const size_t total_size = static_cast<size_t>(num_data) * static_cast<size_t>(num_tree_per_iteration_);
if (total_size > cuda_score_.Size()) { if (total_size > cuda_score_.Size()) {
cuda_score_.Resize(total_size); cuda_score_.Resize(total_size);
} }
...@@ -641,7 +645,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -641,7 +645,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output) { if (need_output) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName(); auto name = sub_metric->GetName();
auto scores = EvalOneMetric(sub_metric, train_score_updater_->score()); auto scores = EvalOneMetric(sub_metric, train_score_updater_->score(), train_score_updater_->num_data());
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf; std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter tmp_buf << "Iteration:" << iter
...@@ -658,7 +662,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -658,7 +662,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output || early_stopping_round_ > 0) { if (need_output || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) { for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score()); auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score(), valid_score_updater_[i]->num_data());
auto name = valid_metrics_[i][j]->GetName(); auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf; std::stringstream tmp_buf;
...@@ -698,7 +702,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const { ...@@ -698,7 +702,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
std::vector<double> ret; std::vector<double> ret;
if (data_idx == 0) { if (data_idx == 0) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto scores = EvalOneMetric(sub_metric, train_score_updater_->score()); auto scores = EvalOneMetric(sub_metric, train_score_updater_->score(), train_score_updater_->num_data());
for (auto score : scores) { for (auto score : scores) {
ret.push_back(score); ret.push_back(score);
} }
...@@ -706,7 +710,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const { ...@@ -706,7 +710,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
} else { } else {
auto used_idx = data_idx - 1; auto used_idx = data_idx - 1;
for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) { for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score()); auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score(), valid_score_updater_[used_idx]->num_data());
for (auto score : test_scores) { for (auto score : test_scores) {
ret.push_back(score); ret.push_back(score);
} }
...@@ -760,6 +764,14 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -760,6 +764,14 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
num_data = valid_score_updater_[used_idx]->num_data(); num_data = valid_score_updater_[used_idx]->num_data();
*out_len = static_cast<int64_t>(num_data) * num_class_; *out_len = static_cast<int64_t>(num_data) * num_class_;
} }
#ifdef USE_CUDA_EXP
std::vector<double> host_raw_scores;
if (boosting_on_gpu_) {
host_raw_scores.resize(static_cast<size_t>(*out_len), 0.0);
CopyFromCUDADeviceToHost<double>(host_raw_scores.data(), raw_scores, static_cast<size_t>(*out_len), __FILE__, __LINE__);
raw_scores = host_raw_scores.data();
}
#endif // USE_CUDA_EXP
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
......
...@@ -443,7 +443,7 @@ class GBDT : public GBDTBase { ...@@ -443,7 +443,7 @@ class GBDT : public GBDTBase {
* \brief eval results for one metric * \brief eval results for one metric
*/ */
virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score) const; virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score, const data_size_t num_data) const;
/*! /*!
* \brief Print metric result of current iteration * \brief Print metric result of current iteration
......
...@@ -189,7 +189,7 @@ class BinaryLogloss: public ObjectiveFunction { ...@@ -189,7 +189,7 @@ class BinaryLogloss: public ObjectiveFunction {
data_size_t NumPositiveData() const override { return num_pos_data_; } data_size_t NumPositiveData() const override { return num_pos_data_; }
private: protected:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of positive samples */ /*! \brief Number of positive samples */
......
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifdef USE_CUDA_EXP

#include "cuda_binary_objective.hpp"

#include <string>
#include <vector>

namespace LightGBM {

CUDABinaryLogloss::CUDABinaryLogloss(const Config& config):
  BinaryLogloss(config), ova_class_id_(-1) {
  cuda_label_ = nullptr;
  cuda_ova_label_ = nullptr;
  cuda_weights_ = nullptr;
  cuda_boost_from_score_ = nullptr;
  cuda_sum_weights_ = nullptr;
  cuda_label_weights_ = nullptr;
}

// One-vs-all constructor: samples whose label equals ova_class_id are treated
// as positives.
// Fix: this constructor previously left every device pointer uninitialized;
// if the object was destroyed before Init() ran, the destructor passed
// garbage pointers to DeallocateCUDAMemory (undefined behavior).
CUDABinaryLogloss::CUDABinaryLogloss(const Config& config, const int ova_class_id):
  BinaryLogloss(config, [ova_class_id](label_t label) { return static_cast<int>(label) == ova_class_id; }), ova_class_id_(ova_class_id) {
  cuda_label_ = nullptr;
  cuda_ova_label_ = nullptr;
  cuda_weights_ = nullptr;
  cuda_boost_from_score_ = nullptr;
  cuda_sum_weights_ = nullptr;
  cuda_label_weights_ = nullptr;
}

// Model-load constructor. Fix: same pointer hygiene as the other constructors.
CUDABinaryLogloss::CUDABinaryLogloss(const std::vector<std::string>& strs): BinaryLogloss(strs) {
  cuda_label_ = nullptr;
  cuda_ova_label_ = nullptr;
  cuda_weights_ = nullptr;
  cuda_boost_from_score_ = nullptr;
  cuda_sum_weights_ = nullptr;
  cuda_label_weights_ = nullptr;
}

CUDABinaryLogloss::~CUDABinaryLogloss() {
  // Release only the buffers this object allocates in Init(); cuda_label_ and
  // cuda_weights_ point into the Metadata's device buffers and are not freed.
  DeallocateCUDAMemory<label_t>(&cuda_ova_label_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_label_weights_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_boost_from_score_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_sum_weights_, __FILE__, __LINE__);
}

void CUDABinaryLogloss::Init(const Metadata& metadata, data_size_t num_data) {
  BinaryLogloss::Init(metadata, num_data);
  if (ova_class_id_ == -1) {
    // Plain binary objective: use the shared device label buffer directly.
    cuda_label_ = metadata.cuda_metadata()->cuda_label();
    cuda_ova_label_ = nullptr;
  } else {
    // One-vs-all: copy the labels, then rewrite them in place to 1/0
    // (1 iff label == ova_class_id_) and point cuda_label_ at the copy.
    InitCUDAMemoryFromHostMemory<label_t>(&cuda_ova_label_, metadata.cuda_metadata()->cuda_label(), static_cast<size_t>(num_data), __FILE__, __LINE__);
    LaunchResetOVACUDALableKernel();
    cuda_label_ = cuda_ova_label_;
  }
  cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
  AllocateCUDAMemory<double>(&cuda_boost_from_score_, 1, __FILE__, __LINE__);
  SetCUDAMemory<double>(cuda_boost_from_score_, 0, 1, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_sum_weights_, 1, __FILE__, __LINE__);
  SetCUDAMemory<double>(cuda_sum_weights_, 0, 1, __FILE__, __LINE__);
  if (label_weights_[0] != 1.0f || label_weights_[1] != 1.0f) {
    // Unequal class weights: upload the two per-class weights to the device.
    InitCUDAMemoryFromHostMemory<double>(&cuda_label_weights_, label_weights_, 2, __FILE__, __LINE__);
  } else {
    cuda_label_weights_ = nullptr;
  }
}

void CUDABinaryLogloss::GetGradients(const double* scores, score_t* gradients, score_t* hessians) const {
  LaunchGetGradientsKernel(scores, gradients, hessians);
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

double CUDABinaryLogloss::BoostFromScore(int) const {
  // The kernels leave the init score in cuda_boost_from_score_ and the average
  // positive ratio (pavg) in cuda_sum_weights_ (the buffer is repurposed).
  LaunchBoostFromScoreKernel();
  SynchronizeCUDADevice(__FILE__, __LINE__);
  double boost_from_score = 0.0f;
  CopyFromCUDADeviceToHost<double>(&boost_from_score, cuda_boost_from_score_, 1, __FILE__, __LINE__);
  double pavg = 0.0f;
  CopyFromCUDADeviceToHost<double>(&pavg, cuda_sum_weights_, 1, __FILE__, __LINE__);
  Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, pavg, boost_from_score);
  return boost_from_score;
}

void CUDABinaryLogloss::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
  LaunchConvertOutputCUDAKernel(num_data, input, output);
}

}  // namespace LightGBM

#endif  // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include <algorithm>
#include "cuda_binary_objective.hpp"
namespace LightGBM {
// Stage-1 reduction for BoostFromScore: each block reduces its slice of the
// labels (and, when USE_WEIGHT, the sample weights) and atomically adds the
// block totals into the global accumulators.
// \param cuda_labels      device label array
// \param num_data         number of samples
// \param out_cuda_sum_labels  accumulator for the (weighted) positive count
// \param out_cuda_sum_weights accumulator for the weight sum (USE_WEIGHT only)
// \param cuda_weights     device sample weights (read only when USE_WEIGHT)
// \param ova_class_id     class id for one-vs-all (read only when IS_OVA)
template <bool IS_OVA, bool USE_WEIGHT>
__global__ void BoostFromScoreKernel_1_BinaryLogloss(const label_t* cuda_labels, const data_size_t num_data, double* out_cuda_sum_labels,
                                                     double* out_cuda_sum_weights, const label_t* cuda_weights, const int ova_class_id) {
  // Per-warp partial sums; 32 slots cover up to 32 warps per block, which
  // matches CALC_INIT_SCORE_BLOCK_SIZE_BINARY == 1024 threads.
  __shared__ double shared_buffer[32];
  const uint32_t mask = 0xffffffff;
  const uint32_t warpLane = threadIdx.x % warpSize;
  const uint32_t warpID = threadIdx.x / warpSize;
  const uint32_t num_warp = blockDim.x / warpSize;
  const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  double label_value = 0.0;
  double weight_value = 0.0;
  if (index < num_data) {
    if (USE_WEIGHT) {
      const label_t cuda_label = cuda_labels[index];
      const double sample_weight = cuda_weights[index];
      // Positive indicator: 1 for the OVA class (or label > 0), else 0.
      const label_t label = IS_OVA ? (static_cast<int>(cuda_label) == ova_class_id ? 1 : 0) : (cuda_label > 0 ? 1 : 0);
      label_value = label * sample_weight;
      weight_value = sample_weight;
    } else {
      const label_t cuda_label = cuda_labels[index];
      label_value = IS_OVA ? (static_cast<int>(cuda_label) == ova_class_id ? 1 : 0) : (cuda_label > 0 ? 1 : 0);
    }
  }
  // Intra-warp tree reduction of label_value via shuffle-down.
  for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) {
    label_value += __shfl_down_sync(mask, label_value, offset);
  }
  // Lane 0 of each warp publishes its warp's partial sum.
  if (warpLane == 0) {
    shared_buffer[warpID] = label_value;
  }
  __syncthreads();
  // Warp 0 reduces the per-warp partials into the block total.
  if (warpID == 0) {
    label_value = (warpLane < num_warp ? shared_buffer[warpLane] : 0);
    for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) {
      label_value += __shfl_down_sync(mask, label_value, offset);
    }
  }
  __syncthreads();
  if (USE_WEIGHT) {
    // Same two-stage reduction for the weight sum; shared_buffer is reused,
    // which is safe because of the __syncthreads() barriers around each use.
    for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) {
      weight_value += __shfl_down_sync(mask, weight_value, offset);
    }
    if (warpLane == 0) {
      shared_buffer[warpID] = weight_value;
    }
    __syncthreads();
    if (warpID == 0) {
      weight_value = (warpLane < num_warp ? shared_buffer[warpLane] : 0);
      for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) {
        weight_value += __shfl_down_sync(mask, weight_value, offset);
      }
    }
    __syncthreads();
  }
  // Thread 0 holds the block totals; accumulate them globally.
  if (threadIdx.x == 0) {
    atomicAdd_system(out_cuda_sum_labels, label_value);
    if (USE_WEIGHT) {
      atomicAdd_system(out_cuda_sum_weights, weight_value);
    }
  }
}
// Stage-2 finalizer for BoostFromScore. Launched with a single thread: turns
// the reduced sums into the average positive ratio and the logit init score.
// On exit the two buffers are repurposed: *out_cuda_sum_weights holds pavg and
// *out_cuda_sum_labels holds the init score, ready to be copied to the host.
template <bool USE_WEIGHT>
__global__ void BoostFromScoreKernel_2_BinaryLogloss(double* out_cuda_sum_labels, double* out_cuda_sum_weights,
                                                     const data_size_t num_data, const double sigmoid) {
  const double sum_labels = *out_cuda_sum_labels;
  // Without sample weights the denominator is simply the sample count.
  const double sum_weights = USE_WEIGHT ? *out_cuda_sum_weights : static_cast<double>(num_data);
  // Clamp the positive ratio away from 0 and 1 so the logit stays finite.
  const double pavg = max(min(sum_labels / sum_weights, 1.0 - kEpsilon), kEpsilon);
  const double init_score = log(pavg / (1.0f - pavg)) / sigmoid;
  *out_cuda_sum_weights = pavg;
  *out_cuda_sum_labels = init_score;
}
// Runs the two-stage init-score computation: stage 1 reduces label/weight sums
// across all blocks, stage 2 (one thread) converts them to pavg + init score.
void CUDABinaryLogloss::LaunchBoostFromScoreKernel() const {
  const int num_blocks = (num_data_ + CALC_INIT_SCORE_BLOCK_SIZE_BINARY - 1) / CALC_INIT_SCORE_BLOCK_SIZE_BINARY;
  const bool is_ova = (ova_class_id_ != -1);
  const bool has_weights = (cuda_weights_ != nullptr);
  // Template arguments: <IS_OVA, USE_WEIGHT>.
  if (is_ova) {
    if (has_weights) {
      BoostFromScoreKernel_1_BinaryLogloss<true, true><<<num_blocks, CALC_INIT_SCORE_BLOCK_SIZE_BINARY>>>
        (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_);
    } else {
      BoostFromScoreKernel_1_BinaryLogloss<true, false><<<num_blocks, CALC_INIT_SCORE_BLOCK_SIZE_BINARY>>>
        (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_);
    }
  } else {
    if (has_weights) {
      BoostFromScoreKernel_1_BinaryLogloss<false, true><<<num_blocks, CALC_INIT_SCORE_BLOCK_SIZE_BINARY>>>
        (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_);
    } else {
      BoostFromScoreKernel_1_BinaryLogloss<false, false><<<num_blocks, CALC_INIT_SCORE_BLOCK_SIZE_BINARY>>>
        (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_);
    }
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
  if (has_weights) {
    BoostFromScoreKernel_2_BinaryLogloss<true><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
  } else {
    BoostFromScoreKernel_2_BinaryLogloss<false><<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_);
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
}
// Computes per-sample gradient/hessian of the binary logloss, one thread per
// data point. Gradients may be scaled by a per-class weight (USE_LABEL_WEIGHT)
// and/or a per-sample weight (USE_WEIGHT); IS_OVA selects one-vs-all labeling.
template <bool USE_LABEL_WEIGHT, bool USE_WEIGHT, bool IS_OVA>
__global__ void GetGradientsKernel_BinaryLogloss(const double* cuda_scores, const label_t* cuda_labels,
                                                 const double* cuda_label_weights, const label_t* cuda_weights, const int ova_class_id,
                                                 const double sigmoid, const data_size_t num_data,
                                                 score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
  const data_size_t data_index = static_cast<data_size_t>(blockDim.x * blockIdx.x + threadIdx.x);
  if (data_index < num_data) {
    const label_t cuda_label = static_cast<int>(cuda_labels[data_index]);
    // is_pos in {0, 1} selects the per-class weight; label in {-1, +1} feeds
    // the logloss formula.
    const int is_pos = IS_OVA ? (static_cast<int>(cuda_label) == ova_class_id ? 1 : 0) : (cuda_label > 0 ? 1 : 0);
    const int label = is_pos == 1 ? 1 : -1;
    const double response = -label * sigmoid / (1.0f + exp(label * sigmoid * cuda_scores[data_index]));
    const double abs_response = fabs(response);
    if (!USE_WEIGHT) {
      if (USE_LABEL_WEIGHT) {
        // Fix: cuda_label_weights has exactly 2 elements (see Init, which
        // uploads label_weights_[0..1]); it must be indexed with is_pos (0/1).
        // The previous code indexed with `label`, which is -1 for negative
        // samples and read out of bounds.
        const double label_weight = cuda_label_weights[is_pos];
        cuda_out_gradients[data_index] = static_cast<score_t>(response * label_weight);
        cuda_out_hessians[data_index] = static_cast<score_t>(abs_response * (sigmoid - abs_response) * label_weight);
      } else {
        cuda_out_gradients[data_index] = static_cast<score_t>(response);
        cuda_out_hessians[data_index] = static_cast<score_t>(abs_response * (sigmoid - abs_response));
      }
    } else {
      const double sample_weight = cuda_weights[data_index];
      if (USE_LABEL_WEIGHT) {
        // Same out-of-bounds fix as above.
        const double label_weight = cuda_label_weights[is_pos];
        cuda_out_gradients[data_index] = static_cast<score_t>(response * label_weight * sample_weight);
        cuda_out_hessians[data_index] = static_cast<score_t>(abs_response * (sigmoid - abs_response) * label_weight * sample_weight);
      } else {
        cuda_out_gradients[data_index] = static_cast<score_t>(response * sample_weight);
        cuda_out_hessians[data_index] = static_cast<score_t>(abs_response * (sigmoid - abs_response) * sample_weight);
      }
    }
  }
}
// Shared argument list for every instantiation of GetGradientsKernel_BinaryLogloss.
#define BINARY_GRADIENT_KERNEL_ARGS \
  scores, \
  cuda_label_, \
  cuda_label_weights_, \
  cuda_weights_, \
  ova_class_id_, \
  sigmoid_, \
  num_data_, \
  gradients, \
  hessians

// Dispatches the gradient kernel with one CUDA thread per data point. The
// runtime configuration (class weights / sample weights / one-vs-all) picks
// one of the eight compile-time specializations.
void CUDABinaryLogloss::LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const {
  const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY;
  const bool use_label_weight = (cuda_label_weights_ != nullptr);
  const bool use_sample_weight = (cuda_weights_ != nullptr);
  // Template arguments: <USE_LABEL_WEIGHT, USE_WEIGHT, IS_OVA>.
  if (ova_class_id_ == -1) {
    if (use_label_weight) {
      if (use_sample_weight) {
        GetGradientsKernel_BinaryLogloss<true, true, false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      } else {
        GetGradientsKernel_BinaryLogloss<true, false, false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      }
    } else {
      if (use_sample_weight) {
        GetGradientsKernel_BinaryLogloss<false, true, false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      } else {
        GetGradientsKernel_BinaryLogloss<false, false, false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      }
    }
  } else {
    if (use_label_weight) {
      if (use_sample_weight) {
        GetGradientsKernel_BinaryLogloss<true, true, true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      } else {
        GetGradientsKernel_BinaryLogloss<true, false, true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      }
    } else {
      if (use_sample_weight) {
        GetGradientsKernel_BinaryLogloss<false, true, true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      } else {
        GetGradientsKernel_BinaryLogloss<false, false, true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(BINARY_GRADIENT_KERNEL_ARGS);
      }
    }
  }
}

#undef BINARY_GRADIENT_KERNEL_ARGS
// Sigmoid transform of raw scores into outputs, one thread per data point.
__global__ void ConvertOutputCUDAKernel_BinaryLogloss(const double sigmoid, const data_size_t num_data, const double* input, double* output) {
  const data_size_t idx = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (idx >= num_data) {
    return;
  }
  output[idx] = 1.0f / (1.0f + exp(-sigmoid * input[idx]));
}

// Launches the sigmoid conversion kernel over `num_data` device scores.
void CUDABinaryLogloss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
  const int num_blocks = (num_data + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY;
  ConvertOutputCUDAKernel_BinaryLogloss<<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(sigmoid_, num_data, input, output);
}
// Rewrites the copied label buffer in place for one-vs-all training:
// 1.0f where the original label equals ova_class_id, else 0.0f.
__global__ void ResetOVACUDALableKernel(
  const int ova_class_id,
  const data_size_t num_data,
  label_t* cuda_label) {
  const data_size_t idx = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  if (idx >= num_data) {
    return;
  }
  cuda_label[idx] = (static_cast<int>(cuda_label[idx]) == ova_class_id) ? 1.0f : 0.0f;
}

// Launches the OVA label rewrite over the owned cuda_ova_label_ buffer.
void CUDABinaryLogloss::LaunchResetOVACUDALableKernel() const {
  const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY;
  ResetOVACUDALableKernel<<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_BINARY>>>(ova_class_id_, num_data_, cuda_ova_label_);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_BINARY_OBJECTIVE_HPP_

#ifdef USE_CUDA_EXP

// Thread-block sizes for the kernels in cuda_binary_objective.cu.
#define GET_GRADIENTS_BLOCK_SIZE_BINARY (1024)
#define CALC_INIT_SCORE_BLOCK_SIZE_BINARY (1024)

#include <LightGBM/cuda/cuda_objective_function.hpp>

#include <string>
#include <vector>

#include "../binary_objective.hpp"

namespace LightGBM {

/*!
 * \brief Binary logloss objective computed on the CUDA device. Reuses the CPU
 *        BinaryLogloss configuration/serialization and adds device-side
 *        gradient, init-score, and output-conversion kernels.
 */
class CUDABinaryLogloss : public CUDAObjectiveInterface, public BinaryLogloss {
 public:
  explicit CUDABinaryLogloss(const Config& config);

  // One-vs-all variant: samples whose label equals ova_class_id are positives.
  explicit CUDABinaryLogloss(const Config& config, const int ova_class_id);

  // Construct from serialized model strings (model-load path).
  explicit CUDABinaryLogloss(const std::vector<std::string>& strs);

  ~CUDABinaryLogloss();

  // Binds device label/weight buffers and allocates the scalar accumulators.
  void Init(const Metadata& metadata, data_size_t num_data) override;

  // Per-sample gradients/hessians, computed on device; all pointers are
  // device memory.
  void GetGradients(const double* scores, score_t* gradients, score_t* hessians) const override;

  double BoostFromScore(int) const override;

  // Sigmoid conversion of raw scores to outputs, on device.
  void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;

  // Exposes ConvertOutputCUDA through the generic ObjectiveFunction hook.
  std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
    return [this] (data_size_t num_data, const double* input, double* output) {
      ConvertOutputCUDA(num_data, input, output);
    };
  }

  bool IsCUDAObjective() const override { return true; }

 private:
  void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const;

  void LaunchBoostFromScoreKernel() const;

  void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;

  void LaunchResetOVACUDALableKernel() const;

  // CUDA memory, held by other objects
  // Device labels: the Metadata's buffer, or cuda_ova_label_ in OVA mode.
  const label_t* cuda_label_;
  // Copy of the labels rewritten to 1/0 for OVA; nullptr in plain binary mode.
  // NOTE(review): this buffer is allocated in Init() and freed in the
  // destructor, so it is actually owned by this object despite the grouping
  // comment above — consider moving it to the owned section.
  label_t* cuda_ova_label_;
  // Device sample weights (Metadata's buffer); may be nullptr.
  const label_t* cuda_weights_;
  // CUDA memory, held by this object
  // Single double: the init score produced by BoostFromScore.
  double* cuda_boost_from_score_;
  // Single double: weight sum, then repurposed to hold pavg.
  double* cuda_sum_weights_;
  // Two doubles copied from label_weights_; nullptr when both weights are 1.
  double* cuda_label_weights_;
  // Class id for one-vs-all training; -1 means plain binary objective.
  const int ova_class_id_ = -1;
};

}  // namespace LightGBM

#endif  // USE_CUDA_EXP

#endif  // LIGHTGBM_OBJECTIVE_CUDA_CUDA_BINARY_OBJECTIVE_HPP_
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include "regression_objective.hpp" #include "regression_objective.hpp"
#include "xentropy_objective.hpp" #include "xentropy_objective.hpp"
#include "cuda/cuda_binary_objective.hpp"
namespace LightGBM { namespace LightGBM {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
...@@ -34,8 +36,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -34,8 +36,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
Log::Warning("Objective poisson is not implemented in cuda_exp version. Fall back to boosting on CPU."); Log::Warning("Objective poisson is not implemented in cuda_exp version. Fall back to boosting on CPU.");
return new RegressionPoissonLoss(config); return new RegressionPoissonLoss(config);
} else if (type == std::string("binary")) { } else if (type == std::string("binary")) {
Log::Warning("Objective binary is not implemented in cuda_exp version. Fall back to boosting on CPU."); return new CUDABinaryLogloss(config);
return new BinaryLogloss(config);
} else if (type == std::string("lambdarank")) { } else if (type == std::string("lambdarank")) {
Log::Warning("Objective lambdarank is not implemented in cuda_exp version. Fall back to boosting on CPU."); Log::Warning("Objective lambdarank is not implemented in cuda_exp version. Fall back to boosting on CPU.");
return new LambdarankNDCG(config); return new LambdarankNDCG(config);
......
...@@ -92,7 +92,6 @@ WARNING | More than one metric available, picking one to plot. ...@@ -92,7 +92,6 @@ WARNING | More than one metric available, picking one to plot.
"INFO | [LightGBM] [Info] LightGBM using CUDA trainer with DP float!!" "INFO | [LightGBM] [Info] LightGBM using CUDA trainer with DP float!!"
] ]
cuda_exp_lines = [ cuda_exp_lines = [
"INFO | [LightGBM] [Warning] Objective binary is not implemented in cuda_exp version. Fall back to boosting on CPU.",
"INFO | [LightGBM] [Warning] Metric auc is not implemented in cuda_exp version. Fall back to evaluation on CPU.", "INFO | [LightGBM] [Warning] Metric auc is not implemented in cuda_exp version. Fall back to evaluation on CPU.",
"INFO | [LightGBM] [Warning] Metric binary_error is not implemented in cuda_exp version. Fall back to evaluation on CPU.", "INFO | [LightGBM] [Warning] Metric binary_error is not implemented in cuda_exp version. Fall back to evaluation on CPU.",
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment