Unverified Commit 9e89ee7f authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

[CUDA] L2 regression objective for cuda_exp (#5452)

* add (l2) regression objective for cuda_exp

* fix lint errors

* correct time tag
parent 2b8fe8b4
......@@ -384,6 +384,12 @@ __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, cons
}
}
// Sums all `n` entries of `values` on the device. Per-block partial sums are
// written into `block_buffer`, and the final total is left in block_buffer[0].
// `block_buffer` must hold at least ceil(n / GLOBAL_PREFIX_SUM_BLOCK_SIZE) elements.
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceSumGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer);

// Computes the dot product sum_i(values1[i] * values2[i]) over `n` entries on
// the device. Same buffer contract as ShuffleReduceSumGlobal: partials in
// `block_buffer`, final result in block_buffer[0].
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceDotProdGlobal(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer);
} // namespace LightGBM
#endif // USE_CUDA_EXP
......
......@@ -48,6 +48,9 @@ class ObjectiveFunction {
const data_size_t*,
data_size_t) const { return ori_output; }
  // CUDA counterpart of leaf-output renewal: refits leaf values on the device.
  // Default is a no-op; CUDA objectives (e.g. CUDARegressionL2loss) override it.
  virtual void RenewTreeOutputCUDA(const double* /*score*/, const data_size_t* /*data_indices_in_leaf*/, const data_size_t* /*num_data_in_leaf*/,
    const data_size_t* /*data_start_in_leaf*/, const int /*num_leaves*/, double* /*leaf_value*/) const {}
  // Initial score to boost from for the given class; 0.0 unless overridden.
  virtual double BoostFromScore(int /*class_id*/) const { return 0.0; }
  // Whether boosting needs to train for the given class; true unless overridden.
  virtual bool ClassNeedTrain(int /*class_id*/) const { return true; }
......
......@@ -77,6 +77,68 @@ void ShufflePrefixSumGlobal(uint64_t* values, size_t len, uint64_t* block_prefix
ShufflePrefixSumGlobalInner<uint64_t>(values, len, block_prefix_sum_buffer);
}
// Collapses the per-block partial sums in block_buffer[0..num_blocks) into a
// single total, left in block_buffer[0]. Intended to be launched with one block.
template <typename T>
__global__ void BlockReduceSum(T* block_buffer, const data_size_t num_blocks) {
  __shared__ T shared_buffer[32];
  // Each thread accumulates a strided slice of the partial sums.
  T local_sum = 0;
  const data_size_t stride = static_cast<data_size_t>(blockDim.x);
  for (data_size_t i = static_cast<data_size_t>(threadIdx.x); i < num_blocks; i += stride) {
    local_sum += block_buffer[i];
  }
  // Block-wide reduction of the per-thread accumulators.
  local_sum = ShuffleReduceSum<T>(local_sum, shared_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    block_buffer[0] = local_sum;
  }
}
// Pass 1 of the global sum: each block reduces its tile of `values` and
// thread 0 writes the tile's sum to block_buffer[blockIdx.x].
template <typename VAL_T, typename REDUCE_T>
__global__ void ShuffleReduceSumGlobalKernel(const VAL_T* values, const data_size_t num_value, REDUCE_T* block_buffer) {
  __shared__ REDUCE_T shared_buffer[32];
  const data_size_t data_index = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  // Threads past the end contribute zero so the reduction stays uniform.
  const REDUCE_T contribution = data_index < num_value ? static_cast<REDUCE_T>(values[data_index]) : 0.0f;
  const REDUCE_T block_sum = ShuffleReduceSum<REDUCE_T>(contribution, shared_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    block_buffer[blockIdx.x] = block_sum;
  }
}
// Host launcher: two-pass global sum. Pass 1 writes one partial sum per block
// into block_buffer; pass 2 folds those partials into block_buffer[0].
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceSumGlobalInner(const VAL_T* values, size_t n, REDUCE_T* block_buffer) {
  const data_size_t total = static_cast<data_size_t>(n);
  // Ceil-divide so a partially filled tail block is still covered.
  const data_size_t num_blocks = (total + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
  ShuffleReduceSumGlobalKernel<VAL_T, REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, total, block_buffer);
  BlockReduceSum<REDUCE_T><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(block_buffer, num_blocks);
}
// Explicit specialization used by the CUDA objectives: float labels, double accumulation.
template <>
void ShuffleReduceSumGlobal<label_t, double>(const label_t* values, size_t n, double* block_buffer) {
  ShuffleReduceSumGlobalInner<label_t, double>(values, n, block_buffer);
}
// Pass 1 of the global dot product: each block reduces its tile of
// values1[i] * values2[i]; thread 0 stores the tile sum at block_buffer[blockIdx.x].
template <typename VAL_T, typename REDUCE_T>
__global__ void ShuffleReduceDotProdGlobalKernel(const VAL_T* values1, const VAL_T* values2, const data_size_t num_value, REDUCE_T* block_buffer) {
  __shared__ REDUCE_T shared_buffer[32];
  const data_size_t data_index = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  const bool in_range = data_index < num_value;
  // Out-of-range threads contribute a zero product.
  const REDUCE_T lhs = in_range ? static_cast<REDUCE_T>(values1[data_index]) : 0.0f;
  const REDUCE_T rhs = in_range ? static_cast<REDUCE_T>(values2[data_index]) : 0.0f;
  const REDUCE_T block_sum = ShuffleReduceSum<REDUCE_T>(lhs * rhs, shared_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    block_buffer[blockIdx.x] = block_sum;
  }
}
// Host launcher: per-block partial dot products, then a single-block pass that
// folds the partials into block_buffer[0].
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceDotProdGlobalInner(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer) {
  const data_size_t total = static_cast<data_size_t>(n);
  const data_size_t num_blocks = (total + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
  ShuffleReduceDotProdGlobalKernel<VAL_T, REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values1, values2, total, block_buffer);
  BlockReduceSum<REDUCE_T><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(block_buffer, num_blocks);
}
// Explicit specialization used by the CUDA objectives: float labels, double accumulation.
template <>
void ShuffleReduceDotProdGlobal<label_t, double>(const label_t* values1, const label_t* values2, size_t n, double* block_buffer) {
  ShuffleReduceDotProdGlobalInner<label_t, double>(values1, values2, n, block_buffer);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include "cuda_regression_objective.hpp"
#include <string>
#include <vector>
namespace LightGBM {
// Constructs from config. The device buffers start null and are allocated
// later in Init(); initialization moved to the member-init list.
// (Members listed in declaration order: cuda_trans_label_, cuda_block_buffer_.)
CUDARegressionL2loss::CUDARegressionL2loss(const Config& config):
RegressionL2loss(config), cuda_trans_label_(nullptr), cuda_block_buffer_(nullptr) {}
// Constructs from a serialized model representation.
// BUG FIX: this constructor previously left cuda_block_buffer_ and
// cuda_trans_label_ uninitialized. The destructor unconditionally passes both
// to DeallocateCUDAMemory, so constructing via this path (model loading)
// handed garbage pointers to the deallocator. Null them here, matching the
// Config constructor.
CUDARegressionL2loss::CUDARegressionL2loss(const std::vector<std::string>& strs):
RegressionL2loss(strs) {
  cuda_block_buffer_ = nullptr;
  cuda_trans_label_ = nullptr;
}
// Releases the device-side buffers this objective owns (allocated in Init()).
// The two frees are independent, so their order is immaterial.
CUDARegressionL2loss::~CUDARegressionL2loss() {
  DeallocateCUDAMemory(&cuda_trans_label_, __FILE__, __LINE__);
  DeallocateCUDAMemory(&cuda_block_buffer_, __FILE__, __LINE__);
}
void CUDARegressionL2loss::Init(const Metadata& metadata, data_size_t num_data) {
  // NOTE(review): the base-class Init is assumed to populate num_data_, sqrt_
  // and trans_label_ used below — confirm in regression_objective.hpp. It must
  // therefore run before the device-side setup.
  RegressionL2loss::Init(metadata, num_data);
  // Device-resident views of labels/weights owned by the Metadata object
  // (not owned here; only cuda_trans_label_/cuda_block_buffer_ are owned).
  cuda_labels_ = metadata.cuda_metadata()->cuda_label();
  cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
  // One partial-sum slot per gradient block; this buffer is reused by the
  // global reduction helpers (e.g. in LaunchCalcInitScoreKernel).
  num_get_gradients_blocks_ = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  AllocateCUDAMemory<double>(&cuda_block_buffer_, static_cast<size_t>(num_get_gradients_blocks_), __FILE__, __LINE__);
  if (sqrt_) {
    // Copy the transformed labels to the device and train against them instead
    // of the raw labels.
    InitCUDAMemoryFromHostMemory<label_t>(&cuda_trans_label_, trans_label_.data(), trans_label_.size(), __FILE__, __LINE__);
    cuda_labels_ = cuda_trans_label_;
  }
}
// Computes per-datapoint gradients and hessians on the device for the current
// scores. Thin forwarder to the kernel launcher.
void CUDARegressionL2loss::GetGradients(const double* score, score_t* gradients, score_t* hessians) const {
  LaunchGetGradientsKernel(score, gradients, hessians);
}
// Initial base score: the (weighted) label mean, reduced on the device.
double CUDARegressionL2loss::BoostFromScore(int) const {
  return LaunchCalcInitScoreKernel();
}
// Device-side output transform (inverse of the sqrt label transform when it
// is active; identity otherwise). Thin forwarder to the kernel launcher.
void CUDARegressionL2loss::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
  LaunchConvertOutputCUDAKernel(num_data, input, output);
}
// Refits the leaf values of a finished tree on the device, then blocks until
// the kernel completes so downstream host code observes the updated values.
// The whole operation is bracketed by the global timer for profiling.
void CUDARegressionL2loss::RenewTreeOutputCUDA(
  const double* score,
  const data_size_t* data_indices_in_leaf,
  const data_size_t* num_data_in_leaf,
  const data_size_t* data_start_in_leaf,
  const int num_leaves,
  double* leaf_value) const {
  global_timer.Start("CUDARegressionL2loss::LaunchRenewTreeOutputCUDAKernel");
  LaunchRenewTreeOutputCUDAKernel(score, data_indices_in_leaf, num_data_in_leaf, data_start_in_leaf, num_leaves, leaf_value);
  SynchronizeCUDADevice(__FILE__, __LINE__);
  global_timer.Stop("CUDARegressionL2loss::LaunchRenewTreeOutputCUDAKernel");
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include "cuda_regression_objective.hpp"
#include <LightGBM/cuda/cuda_algorithms.hpp>
namespace LightGBM {
// Computes the initial score as sum(label * weight) / sum(weight), which
// degenerates to the plain label mean when no weights are present. The
// reductions run on the device; only the two scalar totals come back to host.
double CUDARegressionL2loss::LaunchCalcInitScoreKernel() const {
  double label_sum = 0.0;
  double weight_sum = 0.0;
  const size_t count = static_cast<size_t>(num_data_);
  if (cuda_weights_ == nullptr) {
    // Unweighted: the denominator is simply the row count.
    ShuffleReduceSumGlobal<label_t, double>(cuda_labels_, count, cuda_block_buffer_);
    CopyFromCUDADeviceToHost<double>(&label_sum, cuda_block_buffer_, 1, __FILE__, __LINE__);
    weight_sum = static_cast<double>(num_data_);
  } else {
    // Weighted: numerator is the label/weight dot product, denominator the
    // weight sum. cuda_block_buffer_ is reused for both reductions.
    ShuffleReduceDotProdGlobal<label_t, double>(cuda_labels_, cuda_weights_, count, cuda_block_buffer_);
    CopyFromCUDADeviceToHost<double>(&label_sum, cuda_block_buffer_, 1, __FILE__, __LINE__);
    ShuffleReduceSumGlobal<label_t, double>(cuda_weights_, count, cuda_block_buffer_);
    CopyFromCUDADeviceToHost<double>(&weight_sum, cuda_block_buffer_, 1, __FILE__, __LINE__);
  }
  return label_sum / weight_sum;
}
// Element-wise output transform. When `sqrt` is set, undoes the sqrt label
// transform by mapping y -> sign(y) * y^2; otherwise copies input through.
// 1-D launch with a bounds guard for the partially filled tail block.
__global__ void ConvertOutputCUDAKernel_Regression(const bool sqrt, const data_size_t num_data, const double* input, double* output) {
  // Fix: the index was declared `int` while being cast to data_size_t and
  // compared against a data_size_t bound; declare it as data_size_t for
  // consistency with the other kernels in this file.
  const data_size_t data_index = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (data_index < num_data) {
    if (sqrt) {
      const double sign = input[data_index] >= 0.0f ? 1 : -1;
      output[data_index] = sign * input[data_index] * input[data_index];
    } else {
      output[data_index] = input[data_index];
    }
  }
}
// Host launcher for the output conversion; reuses the gradient-kernel block
// size and ceil-divides so the tail of the data is covered.
void CUDARegressionL2loss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const {
  const int block_count = (num_data + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  ConvertOutputCUDAKernel_Regression<<<block_count, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(sqrt_, num_data, input, output);
}
// L2 gradient/hessian per datapoint: g = w * (score - label), h = w, with
// w == 1 when USE_WEIGHT is false. The compile-time USE_WEIGHT flag keeps the
// unweighted path free of the weight load (cuda_weights may be nullptr then).
template <bool USE_WEIGHT>
__global__ void GetGradientsKernel_RegressionL2(const double* cuda_scores, const label_t* cuda_labels, const label_t* cuda_weights, const data_size_t num_data,
  score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
  const data_size_t data_index = static_cast<data_size_t>(blockDim.x * blockIdx.x + threadIdx.x);
  if (data_index >= num_data) {
    return;  // guard for the partially filled tail block
  }
  // Residual computed in double (score precision), then narrowed to score_t,
  // exactly as in the original formulation.
  const score_t residual = static_cast<score_t>(cuda_scores[data_index] - cuda_labels[data_index]);
  if (USE_WEIGHT) {
    const score_t weight = static_cast<score_t>(cuda_weights[data_index]);
    cuda_out_gradients[data_index] = residual * weight;
    cuda_out_hessians[data_index] = weight;
  } else {
    cuda_out_gradients[data_index] = residual;
    cuda_out_hessians[data_index] = 1.0f;
  }
}
// Dispatches the gradient kernel, selecting the weighted variant at compile
// time based on whether per-datapoint weights exist.
void CUDARegressionL2loss::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
  const int block_count = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  if (cuda_weights_ != nullptr) {
    GetGradientsKernel_RegressionL2<true><<<block_count, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, cuda_weights_, num_data_, gradients, hessians);
  } else {
    GetGradientsKernel_RegressionL2<false><<<block_count, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, nullptr, num_data_, gradients, hessians);
  }
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifndef LIGHTGBM_NEW_CUDA_REGRESSION_OBJECTIVE_HPP_
#define LIGHTGBM_NEW_CUDA_REGRESSION_OBJECTIVE_HPP_
#ifdef USE_CUDA_EXP
#define GET_GRADIENTS_BLOCK_SIZE_REGRESSION (1024)
#include <LightGBM/cuda/cuda_objective_function.hpp>
#include <string>
#include <vector>
#include "../regression_objective.hpp"
namespace LightGBM {
// CUDA implementation of the L2 regression objective. Reuses the CPU
// RegressionL2loss contract and routes the gradient / init-score / output
// paths through device kernels; CUDAObjectiveInterface supplies the
// CUDA-specific virtual hooks.
class CUDARegressionL2loss : public CUDAObjectiveInterface, public RegressionL2loss {
 public:
  explicit CUDARegressionL2loss(const Config& config);

  explicit CUDARegressionL2loss(const std::vector<std::string>& strs);

  ~CUDARegressionL2loss();

  // Caches device label/weight pointers and allocates the reduction buffer.
  void Init(const Metadata& metadata, data_size_t num_data) override;

  // Fills gradients/hessians on the device for the current scores.
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override;

  // Device-side output transform (inverse sqrt transform when enabled).
  void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override;

  double BoostFromScore(int) const override;

  void RenewTreeOutputCUDA(const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
                           const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const override;

  std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
    return [this] (data_size_t num_data, const double* input, double* output) {
      ConvertOutputCUDA(num_data, input, output);
    };
  }

  // The hessian is the constant 1.0 exactly when no per-datapoint weights exist.
  bool IsConstantHessian() const override { return cuda_weights_ == nullptr; }

  bool IsCUDAObjective() const override { return true; }

 protected:
  virtual double LaunchCalcInitScoreKernel() const;

  virtual void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const;

  virtual void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;

  // No-op by default; overridden by objectives that refit leaves on the device.
  virtual void LaunchRenewTreeOutputCUDAKernel(
    const double* /*score*/, const data_size_t* /*data_indices_in_leaf*/, const data_size_t* /*num_data_in_leaf*/,
    const data_size_t* /*data_start_in_leaf*/, const int /*num_leaves*/, double* /*leaf_value*/) const {}

  const label_t* cuda_labels_;            // device labels (transformed labels when sqrt_ is set); not owned
  const label_t* cuda_weights_;           // device weights; nullptr when unweighted; not owned
  label_t* cuda_trans_label_;             // owned device copy of the transformed labels
  double* cuda_block_buffer_;             // owned per-block reduction scratch buffer
  data_size_t num_get_gradients_blocks_;  // grid size of the gradient kernel
  data_size_t num_init_score_blocks_;
};
} // namespace LightGBM
#endif // USE_CUDA_EXP
#endif // LIGHTGBM_NEW_CUDA_REGRESSION_OBJECTIVE_HPP_
......@@ -11,6 +11,7 @@
#include "xentropy_objective.hpp"
#include "cuda/cuda_binary_objective.hpp"
#include "cuda/cuda_regression_objective.hpp"
namespace LightGBM {
......@@ -18,8 +19,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
#ifdef USE_CUDA_EXP
if (config.device_type == std::string("cuda_exp") && config.boosting == std::string("gbdt")) {
if (type == std::string("regression")) {
Log::Warning("Objective regression is not implemented in cuda_exp version. Fall back to boosting on CPU.");
return new RegressionL2loss(config);
return new CUDARegressionL2loss(config);
} else if (type == std::string("regression_l1")) {
Log::Warning("Objective regression_l1 is not implemented in cuda_exp version. Fall back to boosting on CPU.");
return new RegressionL1loss(config);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment