Unverified Commit d78b6bc2 authored by shiyu1994, committed by GitHub

[CUDA] Add L1 regression objective for cuda_exp (#5457)

* add L1 regression objective for cuda_exp

* remove RenewTreeOutputCUDA from CUDARegressionL2loss

* remove mutable and use CUDAVector

* remove white spaces

* remove TODO and document in (#5459)
parent e02ddc44
@@ -107,6 +107,9 @@ __device__ __forceinline__ T ShufflePrefixSumExclusive(T value, T* shared_mem_bu
template <typename T>
void ShufflePrefixSumGlobal(T* values, size_t len, T* block_prefix_sum_buffer);
template <typename VAL_T, typename REDUCE_T, typename INDEX_T>
void GlobalInclusiveArgPrefixSum(const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, size_t n);
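// GlobalInclusiveArgPrefixSum computes an inclusive prefix sum over values
// gathered through an index array: out_values[i] = sum of
// in_values[sorted_indices[j]] for j in [0, i]. block_buffer is device
// scratch space with one slot per block plus one leading base slot.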
template <typename T>
__device__ __forceinline__ T ShuffleReduceSumWarp(T value, const data_size_t len) {
if (len > 0) {
@@ -384,12 +387,112 @@ __device__ void BitonicArgSortDevice(const VAL_T* values, INDEX_T* indices, cons
}
}
template <typename VAL_T, typename INDEX_T, bool ASCENDING>
void BitonicArgSortGlobal(const VAL_T* values, INDEX_T* indices, const size_t len);
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceSumGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer);
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceDotProdGlobal(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer);
template <typename VAL_T, typename REDUCE_VAL_T, typename INDEX_T>
__device__ void ShuffleSortedPrefixSumDevice(const VAL_T* in_values,
const INDEX_T* sorted_indices,
REDUCE_VAL_T* out_values,
const INDEX_T num_data) {
__shared__ REDUCE_VAL_T shared_buffer[32];
const INDEX_T num_data_per_thread = (num_data + static_cast<INDEX_T>(blockDim.x) - 1) / static_cast<INDEX_T>(blockDim.x);
const INDEX_T start = num_data_per_thread * static_cast<INDEX_T>(threadIdx.x);
const INDEX_T end = min(start + num_data_per_thread, num_data);
REDUCE_VAL_T thread_sum = 0;
for (INDEX_T index = start; index < end; ++index) {
thread_sum += static_cast<REDUCE_VAL_T>(in_values[sorted_indices[index]]);
}
__syncthreads();
const REDUCE_VAL_T thread_base = ShufflePrefixSumExclusive<REDUCE_VAL_T>(thread_sum, shared_buffer);
// accumulate within this thread's chunk so out_values is a true inclusive prefix sum
if (start < end) {
out_values[start] = thread_base + static_cast<REDUCE_VAL_T>(in_values[sorted_indices[start]]);
}
for (INDEX_T index = start + 1; index < end; ++index) {
out_values[index] = out_values[index - 1] + static_cast<REDUCE_VAL_T>(in_values[sorted_indices[index]]);
}
__syncthreads();
}
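// For reference, a sequential equivalent of ShuffleSortedPrefixSumDevice
// (an illustrative sketch, not part of the diff):
//   REDUCE_VAL_T acc = 0;
//   for (INDEX_T i = 0; i < num_data; ++i) {
//     acc += static_cast<REDUCE_VAL_T>(in_values[sorted_indices[i]]);
//     out_values[i] = acc;  // inclusive prefix sum in sorted order
//   }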
template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename WEIGHT_REDUCE_T, bool ASCENDING, bool USE_WEIGHT>
__global__ void PercentileGlobalKernel(const VAL_T* values,
const WEIGHT_T* weights,
const INDEX_T* sorted_indices,
const WEIGHT_REDUCE_T* weights_prefix_sum,
const double alpha,
const INDEX_T len,
VAL_T* out_value) {
if (!USE_WEIGHT) {
const double float_pos = (1.0f - alpha) * len;
const INDEX_T pos = static_cast<INDEX_T>(float_pos);
if (pos < 1) {
*out_value = values[sorted_indices[0]];
} else if (pos >= len) {
*out_value = values[sorted_indices[len - 1]];
} else {
const double bias = float_pos - static_cast<double>(pos);
const VAL_T v1 = values[sorted_indices[pos - 1]];
const VAL_T v2 = values[sorted_indices[pos]];
*out_value = static_cast<VAL_T>(v1 - (v1 - v2) * bias);
}
} else {
const WEIGHT_REDUCE_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
__shared__ INDEX_T pos;
if (threadIdx.x == 0) {
pos = len;
}
__syncthreads();
for (INDEX_T index = static_cast<INDEX_T>(threadIdx.x); index < len; index += static_cast<INDEX_T>(blockDim.x)) {
if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
pos = index;
}
}
__syncthreads();
pos = min(pos, len - 1);
if (pos == 0 || pos == len - 1) {
*out_value = values[sorted_indices[pos]];
return;
}
const VAL_T v1 = values[sorted_indices[pos - 1]];
const VAL_T v2 = values[sorted_indices[pos]];
*out_value = static_cast<VAL_T>(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
}
}
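// Both branches of PercentileGlobalKernel interpolate linearly between the
// two sorted values bracketing the requested percentile. Unweighted, the
// fractional position is (1 - alpha) * len; weighted, the point where the
// weight prefix sum crosses threshold = (1 - alpha) * total_weight plays the
// same role, giving
//   out = v1 - (v1 - v2) * (threshold - prefix[pos - 1]) / (prefix[pos] - prefix[pos - 1]).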
template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename WEIGHT_REDUCE_T, bool ASCENDING, bool USE_WEIGHT>
void PercentileGlobal(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
WEIGHT_REDUCE_T* weights_prefix_sum,
WEIGHT_REDUCE_T* weights_prefix_sum_buffer,
const double alpha,
const INDEX_T len,
VAL_T* cuda_out_value) {
if (len <= 1) {
CopyFromCUDADeviceToCUDADevice<VAL_T>(cuda_out_value, values, 1, __FILE__, __LINE__);
return;
}
BitonicArgSortGlobal<VAL_T, INDEX_T, ASCENDING>(values, indices, len);
SynchronizeCUDADevice(__FILE__, __LINE__);
if (USE_WEIGHT) {
GlobalInclusiveArgPrefixSum<WEIGHT_T, WEIGHT_REDUCE_T, INDEX_T>(indices, weights, weights_prefix_sum, weights_prefix_sum_buffer, static_cast<size_t>(len));
}
SynchronizeCUDADevice(__FILE__, __LINE__);
PercentileGlobalKernel<VAL_T, INDEX_T, WEIGHT_T, WEIGHT_REDUCE_T, ASCENDING, USE_WEIGHT><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, weights, indices, weights_prefix_sum, alpha, len, cuda_out_value);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
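// A minimal host-side usage sketch of PercentileGlobal (hypothetical buffer
// names; assumes d_values, d_indices and d_result are device allocations of
// len, len and 1 elements respectively):
//   // unweighted median (alpha = 0.5) of d_values
//   PercentileGlobal<double, data_size_t, label_t, double, false, false>(
//       d_values, nullptr, d_indices, nullptr, nullptr, 0.5, len, d_result);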
template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename REDUCE_WEIGHT_T, bool ASCENDING, bool USE_WEIGHT>
__device__ VAL_T PercentileDevice(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
REDUCE_WEIGHT_T* weights_prefix_sum,
const double alpha,
const INDEX_T len);
} // namespace LightGBM
#endif // USE_CUDA_EXP
......
@@ -167,7 +167,7 @@ class CUDAVector {
return host_vector;
}
T* RawData() {
T* RawData() const {
return data_;
}
......
@@ -139,6 +139,334 @@ void ShuffleReduceDotProdGlobal<label_t, double>(const label_t* values1, const l
ShuffleReduceDotProdGlobalInner(values1, values2, n, block_buffer);
}
template <typename INDEX_T, typename VAL_T, typename REDUCE_T>
__global__ void GlobalInclusiveArgPrefixSumKernel(
const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, data_size_t num_data) {
__shared__ REDUCE_T shared_buffer[32];
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
REDUCE_T value = static_cast<REDUCE_T>(data_index < num_data ? in_values[sorted_indices[data_index]] : 0);
__syncthreads();
value = ShufflePrefixSum<REDUCE_T>(value, shared_buffer);
if (data_index < num_data) {
out_values[data_index] = value;
}
if (threadIdx.x == blockDim.x - 1) {
block_buffer[blockIdx.x + 1] = value;
}
}
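// Each block stores its inclusive total to block_buffer[blockIdx.x + 1];
// slot 0 is expected to hold zero as the base of the first block. The two
// kernels below turn these totals into per-block bases and add them back.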
template <typename T>
__global__ void GlobalInclusivePrefixSumReduceBlockKernel(T* block_buffer, data_size_t num_blocks) {
__shared__ T shared_buffer[32];
T thread_sum = 0;
const data_size_t num_blocks_per_thread = (num_blocks + static_cast<data_size_t>(blockDim.x)) / static_cast<data_size_t>(blockDim.x);
const data_size_t thread_start_block_index = static_cast<data_size_t>(threadIdx.x) * num_blocks_per_thread;
const data_size_t thread_end_block_index = min(thread_start_block_index + num_blocks_per_thread, num_blocks + 1);
for (data_size_t block_index = thread_start_block_index; block_index < thread_end_block_index; ++block_index) {
thread_sum += block_buffer[block_index];
}
ShufflePrefixSumExclusive<T>(thread_sum, shared_buffer);
for (data_size_t block_index = thread_start_block_index; block_index < thread_end_block_index; ++block_index) {
block_buffer[block_index] += thread_sum;
}
}
template <typename T>
__global__ void GlobalInclusivePrefixSumAddBlockBaseKernel(const T* block_buffer, T* values, data_size_t num_data) {
const T block_sum_base = block_buffer[blockIdx.x];
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
values[data_index] += block_sum_base;
}
}
template <typename VAL_T, typename REDUCE_T, typename INDEX_T>
void GlobalInclusiveArgPrefixSumInner(const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, size_t n) {
const data_size_t num_data = static_cast<data_size_t>(n);
const data_size_t num_blocks = (num_data + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
GlobalInclusiveArgPrefixSumKernel<INDEX_T, VAL_T, REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(
sorted_indices, in_values, out_values, block_buffer, num_data);
SynchronizeCUDADevice(__FILE__, __LINE__);
GlobalInclusivePrefixSumReduceBlockKernel<REDUCE_T><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(
block_buffer, num_blocks);
SynchronizeCUDADevice(__FILE__, __LINE__);
GlobalInclusivePrefixSumAddBlockBaseKernel<REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(
block_buffer, out_values, num_data);
SynchronizeCUDADevice(__FILE__, __LINE__);
}
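// GlobalInclusiveArgPrefixSumInner is the classic three-phase global scan:
// (1) per-block inclusive scan of the gathered values, (2) a single-block
// scan converting per-block totals into per-block bases, and (3) adding each
// block's base back to its elements.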
template <>
void GlobalInclusiveArgPrefixSum<label_t, double, data_size_t>(const data_size_t* sorted_indices, const label_t* in_values, double* out_values, double* block_buffer, size_t n) {
GlobalInclusiveArgPrefixSumInner<label_t, double, data_size_t>(sorted_indices, in_values, out_values, block_buffer, n);
}
template <typename VAL_T, typename INDEX_T, bool ASCENDING>
__global__ void BitonicArgSortGlobalKernel(const VAL_T* values, INDEX_T* indices, const int num_total_data) {
const int thread_index = static_cast<int>(threadIdx.x);
const int low = static_cast<int>(blockIdx.x * BITONIC_SORT_NUM_ELEMENTS);
const bool outer_ascending = ASCENDING ? (blockIdx.x % 2 == 0) : (blockIdx.x % 2 == 1);
const VAL_T* values_pointer = values + low;
INDEX_T* indices_pointer = indices + low;
const int num_data = min(BITONIC_SORT_NUM_ELEMENTS, num_total_data - low);
__shared__ VAL_T shared_values[BITONIC_SORT_NUM_ELEMENTS];
__shared__ INDEX_T shared_indices[BITONIC_SORT_NUM_ELEMENTS];
if (thread_index < num_data) {
shared_values[thread_index] = values_pointer[thread_index];
shared_indices[thread_index] = static_cast<INDEX_T>(thread_index + blockIdx.x * blockDim.x);
}
__syncthreads();
for (int depth = BITONIC_SORT_DEPTH - 1; depth >= 1; --depth) {
const int segment_length = 1 << (BITONIC_SORT_DEPTH - depth);
const int segment_index = thread_index / segment_length;
const bool ascending = outer_ascending ? (segment_index % 2 == 0) : (segment_index % 2 == 1);
const int num_total_segment = (num_data + segment_length - 1) / segment_length;
{
const int inner_depth = depth;
const int inner_segment_length_half = 1 << (BITONIC_SORT_DEPTH - 1 - inner_depth);
const int inner_segment_index_half = thread_index / inner_segment_length_half;
const int offset = ((inner_segment_index_half >> 1) == num_total_segment - 1 && ascending == outer_ascending) ?
(num_total_segment * segment_length - num_data) : 0;
const int segment_start = segment_index * segment_length;
if (inner_segment_index_half % 2 == 0) {
if (thread_index >= offset + segment_start) {
const int index_to_compare = thread_index + inner_segment_length_half - offset;
const INDEX_T this_index = shared_indices[thread_index];
const INDEX_T other_index = shared_indices[index_to_compare];
const VAL_T this_value = shared_values[thread_index];
const VAL_T other_value = shared_values[index_to_compare];
if (index_to_compare < num_data && (this_value > other_value) == ascending) {
shared_indices[thread_index] = other_index;
shared_indices[index_to_compare] = this_index;
shared_values[thread_index] = other_value;
shared_values[index_to_compare] = this_value;
}
}
}
__syncthreads();
}
for (int inner_depth = depth + 1; inner_depth < BITONIC_SORT_DEPTH; ++inner_depth) {
const int inner_segment_length_half = 1 << (BITONIC_SORT_DEPTH - 1 - inner_depth);
const int inner_segment_index_half = thread_index / inner_segment_length_half;
if (inner_segment_index_half % 2 == 0) {
const int index_to_compare = thread_index + inner_segment_length_half;
const INDEX_T this_index = shared_indices[thread_index];
const INDEX_T other_index = shared_indices[index_to_compare];
const VAL_T this_value = shared_values[thread_index];
const VAL_T other_value = shared_values[index_to_compare];
if (index_to_compare < num_data && (this_value > other_value) == ascending) {
shared_indices[thread_index] = other_index;
shared_indices[index_to_compare] = this_index;
shared_values[thread_index] = other_value;
shared_values[index_to_compare] = this_value;
}
}
__syncthreads();
}
}
if (thread_index < num_data) {
indices_pointer[thread_index] = shared_indices[thread_index];
}
}
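// BitonicArgSortGlobalKernel stages values and indices in shared memory, one
// element per thread, and alternates the sort direction across blocks so
// that adjacent sorted chunks form the bitonic sequences consumed by the
// cross-block merge phase below.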
template <typename VAL_T, typename INDEX_T, bool ASCENDING>
__global__ void BitonicArgSortMergeKernel(const VAL_T* values, INDEX_T* indices, const int segment_length, const int len) {
const int thread_index = static_cast<int>(threadIdx.x + blockIdx.x * blockDim.x);
const int segment_index = thread_index / segment_length;
const bool ascending = ASCENDING ? (segment_index % 2 == 0) : (segment_index % 2 == 1);
__shared__ VAL_T shared_values[BITONIC_SORT_NUM_ELEMENTS];
__shared__ INDEX_T shared_indices[BITONIC_SORT_NUM_ELEMENTS];
const int offset = static_cast<int>(blockIdx.x * blockDim.x);
const int local_len = min(BITONIC_SORT_NUM_ELEMENTS, len - offset);
if (thread_index < len) {
const INDEX_T index = indices[thread_index];
shared_values[threadIdx.x] = values[index];
shared_indices[threadIdx.x] = index;
}
__syncthreads();
int half_segment_length = BITONIC_SORT_NUM_ELEMENTS / 2;
while (half_segment_length >= 1) {
const int half_segment_index = static_cast<int>(threadIdx.x) / half_segment_length;
if (half_segment_index % 2 == 0) {
const int index_to_compare = static_cast<int>(threadIdx.x) + half_segment_length;
const INDEX_T this_index = shared_indices[threadIdx.x];
const INDEX_T other_index = shared_indices[index_to_compare];
const VAL_T this_value = shared_values[threadIdx.x];
const VAL_T other_value = shared_values[index_to_compare];
if (index_to_compare < local_len && ((this_value > other_value) == ascending)) {
shared_indices[threadIdx.x] = other_index;
shared_indices[index_to_compare] = this_index;
shared_values[threadIdx.x] = other_value;
shared_values[index_to_compare] = this_value;
}
}
__syncthreads();
half_segment_length >>= 1;
}
if (thread_index < len) {
indices[thread_index] = shared_indices[threadIdx.x];
}
}
template <typename VAL_T, typename INDEX_T, bool ASCENDING, bool BEGIN>
__global__ void BitonicArgCompareKernel(const VAL_T* values, INDEX_T* indices, const int half_segment_length, const int outer_segment_length, const int len) {
const int thread_index = static_cast<int>(threadIdx.x + blockIdx.x * blockDim.x);
const int segment_index = thread_index / outer_segment_length;
const int half_segment_index = thread_index / half_segment_length;
const bool ascending = ASCENDING ? (segment_index % 2 == 0) : (segment_index % 2 == 1);
if (half_segment_index % 2 == 0) {
const int num_total_segment = (len + outer_segment_length - 1) / outer_segment_length;
if (BEGIN && (half_segment_index >> 1) == num_total_segment - 1 && ascending == ASCENDING) {
const int offset = num_total_segment * outer_segment_length - len;
const int segment_start = segment_index * outer_segment_length;
if (thread_index >= offset + segment_start) {
const int index_to_compare = thread_index + half_segment_length - offset;
if (index_to_compare < len) {
const INDEX_T this_index = indices[thread_index];
const INDEX_T other_index = indices[index_to_compare];
if ((values[this_index] > values[other_index]) == ascending) {
indices[thread_index] = other_index;
indices[index_to_compare] = this_index;
}
}
}
} else {
const int index_to_compare = thread_index + half_segment_length;
if (index_to_compare < len) {
const INDEX_T this_index = indices[thread_index];
const INDEX_T other_index = indices[index_to_compare];
if ((values[this_index] > values[other_index]) == ascending) {
indices[thread_index] = other_index;
indices[index_to_compare] = this_index;
}
}
}
}
}
template <typename VAL_T, typename INDEX_T, bool ASCENDING>
void BitonicArgSortGlobalHelper(const VAL_T* values, INDEX_T* indices, const size_t len) {
int max_depth = 1;
int len_to_shift = static_cast<int>(len) - 1;
while (len_to_shift > 0) {
++max_depth;
len_to_shift >>= 1;
}
const int num_blocks = (static_cast<int>(len) + BITONIC_SORT_NUM_ELEMENTS - 1) / BITONIC_SORT_NUM_ELEMENTS;
BitonicArgSortGlobalKernel<VAL_T, INDEX_T, ASCENDING><<<num_blocks, BITONIC_SORT_NUM_ELEMENTS>>>(values, indices, static_cast<int>(len));
SynchronizeCUDADevice(__FILE__, __LINE__);
for (int depth = max_depth - 11; depth >= 1; --depth) {
const int segment_length = (1 << (max_depth - depth));
int half_segment_length = (segment_length >> 1);
{
BitonicArgCompareKernel<VAL_T, INDEX_T, ASCENDING, true><<<num_blocks, BITONIC_SORT_NUM_ELEMENTS>>>(
values, indices, half_segment_length, segment_length, static_cast<int>(len));
SynchronizeCUDADevice(__FILE__, __LINE__);
half_segment_length >>= 1;
}
for (int inner_depth = depth + 1; inner_depth <= max_depth - 11; ++inner_depth) {
BitonicArgCompareKernel<VAL_T, INDEX_T, ASCENDING, false><<<num_blocks, BITONIC_SORT_NUM_ELEMENTS>>>(
values, indices, half_segment_length, segment_length, static_cast<int>(len));
SynchronizeCUDADevice(__FILE__, __LINE__);
half_segment_length >>= 1;
}
BitonicArgSortMergeKernel<VAL_T, INDEX_T, ASCENDING><<<num_blocks, BITONIC_SORT_NUM_ELEMENTS>>>(
values, indices, segment_length, static_cast<int>(len));
SynchronizeCUDADevice(__FILE__, __LINE__);
}
}
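// BitonicArgSortGlobalHelper runs in three phases: BitonicArgSortGlobalKernel
// sorts each chunk in shared memory; for every bitonic depth that spans more
// than one block, BitonicArgCompareKernel performs the cross-block
// compare-exchange steps in global memory; BitonicArgSortMergeKernel then
// finishes the remaining within-block steps of that depth.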
template <>
void BitonicArgSortGlobal<double, data_size_t, false>(const double* values, data_size_t* indices, const size_t len) {
BitonicArgSortGlobalHelper<double, data_size_t, false>(values, indices, len);
}
template <>
void BitonicArgSortGlobal<double, data_size_t, true>(const double* values, data_size_t* indices, const size_t len) {
BitonicArgSortGlobalHelper<double, data_size_t, true>(values, indices, len);
}
template <>
void BitonicArgSortGlobal<label_t, data_size_t, false>(const label_t* values, data_size_t* indices, const size_t len) {
BitonicArgSortGlobalHelper<label_t, data_size_t, false>(values, indices, len);
}
template <>
void BitonicArgSortGlobal<data_size_t, int, true>(const data_size_t* values, int* indices, const size_t len) {
BitonicArgSortGlobalHelper<data_size_t, int, true>(values, indices, len);
}
template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename REDUCE_WEIGHT_T, bool ASCENDING, bool USE_WEIGHT>
__device__ VAL_T PercentileDeviceInner(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
REDUCE_WEIGHT_T* weights_prefix_sum,
const double alpha,
const INDEX_T len) {
if (len <= 1) {
return values[0];
}
if (!USE_WEIGHT) {
BitonicArgSortDevice<VAL_T, INDEX_T, ASCENDING, BITONIC_SORT_NUM_ELEMENTS / 2, 10>(values, indices, len);
const double float_pos = (1.0f - alpha) * len;
const INDEX_T pos = static_cast<INDEX_T>(float_pos);
if (pos < 1) {
return values[indices[0]];
} else if (pos >= len) {
return values[indices[len - 1]];
} else {
const double bias = float_pos - pos;
const VAL_T v1 = values[indices[pos - 1]];
const VAL_T v2 = values[indices[pos]];
return static_cast<VAL_T>(v1 - (v1 - v2) * bias);
}
} else {
BitonicArgSortDevice<VAL_T, INDEX_T, ASCENDING, BITONIC_SORT_NUM_ELEMENTS / 4, 9>(values, indices, len);
ShuffleSortedPrefixSumDevice<WEIGHT_T, REDUCE_WEIGHT_T, INDEX_T>(weights, indices, weights_prefix_sum, len);
const REDUCE_WEIGHT_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
__shared__ INDEX_T pos;
if (threadIdx.x == 0) {
pos = len;
}
__syncthreads();
for (INDEX_T index = static_cast<INDEX_T>(threadIdx.x); index < len; index += static_cast<INDEX_T>(blockDim.x)) {
if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
pos = index;
}
}
__syncthreads();
pos = min(pos, len - 1);
if (pos == 0 || pos == len - 1) {
return values[indices[pos]];
}
const VAL_T v1 = values[indices[pos - 1]];
const VAL_T v2 = values[indices[pos]];
return static_cast<VAL_T>(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
}
}
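// PercentileDeviceInner is the single-block, device-side counterpart of
// PercentileGlobalKernel: it sorts and scans entirely within one thread
// block, which lets RenewTreeOutputCUDAKernel_RegressionL1 evaluate one
// percentile per leaf.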
template <>
__device__ double PercentileDevice<double, data_size_t, label_t, double, false, true>(
const double* values,
const label_t* weights,
data_size_t* indices,
double* weights_prefix_sum,
const double alpha,
const data_size_t len) {
return PercentileDeviceInner<double, data_size_t, label_t, double, false, true>(values, weights, indices, weights_prefix_sum, alpha, len);
}
template <>
__device__ double PercentileDevice<double, data_size_t, label_t, double, false, false>(
const double* values,
const label_t* weights,
data_size_t* indices,
double* weights_prefix_sum,
const double alpha,
const data_size_t len) {
return PercentileDeviceInner<double, data_size_t, label_t, double, false, false>(values, weights, indices, weights_prefix_sum, alpha, len);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
@@ -14,28 +14,23 @@
namespace LightGBM {
CUDARegressionL2loss::CUDARegressionL2loss(const Config& config):
RegressionL2loss(config) {
cuda_block_buffer_ = nullptr;
cuda_trans_label_ = nullptr;
}
RegressionL2loss(config) {}
CUDARegressionL2loss::CUDARegressionL2loss(const std::vector<std::string>& strs):
RegressionL2loss(strs) {}
CUDARegressionL2loss::~CUDARegressionL2loss() {
DeallocateCUDAMemory(&cuda_block_buffer_, __FILE__, __LINE__);
DeallocateCUDAMemory(&cuda_trans_label_, __FILE__, __LINE__);
}
CUDARegressionL2loss::~CUDARegressionL2loss() {}
void CUDARegressionL2loss::Init(const Metadata& metadata, data_size_t num_data) {
RegressionL2loss::Init(metadata, num_data);
cuda_labels_ = metadata.cuda_metadata()->cuda_label();
cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
num_get_gradients_blocks_ = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
AllocateCUDAMemory<double>(&cuda_block_buffer_, static_cast<size_t>(num_get_gradients_blocks_), __FILE__, __LINE__);
cuda_block_buffer_.Resize(static_cast<size_t>(num_get_gradients_blocks_));
if (sqrt_) {
InitCUDAMemoryFromHostMemory<label_t>(&cuda_trans_label_, trans_label_.data(), trans_label_.size(), __FILE__, __LINE__);
cuda_labels_ = cuda_trans_label_;
cuda_trans_label_.Resize(trans_label_.size());
CopyFromHostToCUDADevice<label_t>(cuda_trans_label_.RawData(), trans_label_.data(), trans_label_.size(), __FILE__, __LINE__);
cuda_labels_ = cuda_trans_label_.RawData();
}
}
@@ -51,20 +46,41 @@ void CUDARegressionL2loss::ConvertOutputCUDA(const data_size_t num_data, const d
LaunchConvertOutputCUDAKernel(num_data, input, output);
}
void CUDARegressionL2loss::RenewTreeOutputCUDA(
CUDARegressionL1loss::CUDARegressionL1loss(const Config& config):
CUDARegressionL2loss(config) {}
CUDARegressionL1loss::CUDARegressionL1loss(const std::vector<std::string>& strs):
CUDARegressionL2loss(strs) {}
CUDARegressionL1loss::~CUDARegressionL1loss() {}
void CUDARegressionL1loss::Init(const Metadata& metadata, data_size_t num_data) {
CUDARegressionL2loss::Init(metadata, num_data);
cuda_data_indices_buffer_.Resize(static_cast<size_t>(num_data));
cuda_percentile_result_.Resize(1);
if (cuda_weights_ != nullptr) {
const int num_blocks = (num_data + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION + 1;
cuda_weights_prefix_sum_.Resize(static_cast<size_t>(num_data));
cuda_weights_prefix_sum_buffer_.Resize(static_cast<size_t>(num_blocks));
cuda_weight_by_leaf_buffer_.Resize(static_cast<size_t>(num_data));
}
cuda_residual_buffer_.Resize(static_cast<size_t>(num_data));
}
void CUDARegressionL1loss::RenewTreeOutputCUDA(
const double* score,
const data_size_t* data_indices_in_leaf,
const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf,
const int num_leaves,
double* leaf_value) const {
global_timer.Start("CUDARegressionL2loss::LaunchRenewTreeOutputCUDAKernel");
global_timer.Start("CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel");
LaunchRenewTreeOutputCUDAKernel(score, data_indices_in_leaf, num_data_in_leaf, data_start_in_leaf, num_leaves, leaf_value);
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDARegressionL2loss::LaunchRenewTreeOutputCUDAKernel");
global_timer.Stop("CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel");
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
@@ -14,14 +14,14 @@ namespace LightGBM {
double CUDARegressionL2loss::LaunchCalcInitScoreKernel() const {
double label_sum = 0.0f, weight_sum = 0.0f;
if (cuda_weights_ == nullptr) {
ShuffleReduceSumGlobal<label_t, double>(cuda_labels_, static_cast<size_t>(num_data_), cuda_block_buffer_);
CopyFromCUDADeviceToHost<double>(&label_sum, cuda_block_buffer_, 1, __FILE__, __LINE__);
ShuffleReduceSumGlobal<label_t, double>(cuda_labels_, static_cast<size_t>(num_data_), cuda_block_buffer_.RawData());
CopyFromCUDADeviceToHost<double>(&label_sum, cuda_block_buffer_.RawData(), 1, __FILE__, __LINE__);
weight_sum = static_cast<double>(num_data_);
} else {
ShuffleReduceDotProdGlobal<label_t, double>(cuda_labels_, cuda_weights_, static_cast<size_t>(num_data_), cuda_block_buffer_);
CopyFromCUDADeviceToHost<double>(&label_sum, cuda_block_buffer_, 1, __FILE__, __LINE__);
ShuffleReduceSumGlobal<label_t, double>(cuda_weights_, static_cast<size_t>(num_data_), cuda_block_buffer_);
CopyFromCUDADeviceToHost<double>(&weight_sum, cuda_block_buffer_, 1, __FILE__, __LINE__);
ShuffleReduceDotProdGlobal<label_t, double>(cuda_labels_, cuda_weights_, static_cast<size_t>(num_data_), cuda_block_buffer_.RawData());
CopyFromCUDADeviceToHost<double>(&label_sum, cuda_block_buffer_.RawData(), 1, __FILE__, __LINE__);
ShuffleReduceSumGlobal<label_t, double>(cuda_weights_, static_cast<size_t>(num_data_), cuda_block_buffer_.RawData());
CopyFromCUDADeviceToHost<double>(&weight_sum, cuda_block_buffer_.RawData(), 1, __FILE__, __LINE__);
}
return label_sum / weight_sum;
}
@@ -69,6 +69,126 @@ void CUDARegressionL2loss::LaunchGetGradientsKernel(const double* score, score_t
}
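// LaunchCalcInitScoreKernel below computes the boosting init score as a
// (weighted) label percentile entirely on the GPU; only the single scalar
// result is copied back to the host.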
double CUDARegressionL1loss::LaunchCalcInitScoreKernel() const {
const double alpha = 0.5f;  // median of the labels, the optimal constant for L1
if (cuda_weights_ == nullptr) {
PercentileGlobal<label_t, data_size_t, label_t, double, false, false>(
cuda_labels_, nullptr, cuda_data_indices_buffer_.RawData(), nullptr, nullptr, alpha, num_data_, cuda_percentile_result_.RawData());
} else {
PercentileGlobal<label_t, data_size_t, label_t, double, false, true>(
cuda_labels_, cuda_weights_, cuda_data_indices_buffer_.RawData(), cuda_weights_prefix_sum_.RawData(),
cuda_weights_prefix_sum_buffer_.RawData(), alpha, num_data_, cuda_percentile_result_.RawData());
}
label_t percentile_result = 0.0f;
CopyFromCUDADeviceToHost<label_t>(&percentile_result, cuda_percentile_result_.RawData(), 1, __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
return static_cast<double>(percentile_result);
}
template <bool USE_WEIGHT>
__global__ void GetGradientsKernel_RegressionL1(const double* cuda_scores, const label_t* cuda_labels, const label_t* cuda_weights, const data_size_t num_data,
score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
const data_size_t data_index = static_cast<data_size_t>(blockDim.x * blockIdx.x + threadIdx.x);
if (data_index < num_data) {
if (!USE_WEIGHT) {
const double diff = cuda_scores[data_index] - static_cast<double>(cuda_labels[data_index]);
cuda_out_gradients[data_index] = static_cast<score_t>((diff > 0.0f) - (diff < 0.0f));
cuda_out_hessians[data_index] = 1.0f;
} else {
const double diff = cuda_scores[data_index] - static_cast<double>(cuda_labels[data_index]);
const score_t weight = static_cast<score_t>(cuda_weights[data_index]);
cuda_out_gradients[data_index] = static_cast<score_t>((diff > 0.0f) - (diff < 0.0f)) * weight;
cuda_out_hessians[data_index] = weight;
}
}
}
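// In the kernel above, (diff > 0.0f) - (diff < 0.0f) is a branchless sign():
// it yields +1, -1, or 0, which is the L1 gradient d|score - label|/dscore,
// optionally scaled by the sample weight; the hessian slot carries a
// constant 1 (or the weight) since the true second derivative is zero.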
void CUDARegressionL1loss::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
if (cuda_weights_ == nullptr) {
GetGradientsKernel_RegressionL1<false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, nullptr, num_data_, gradients, hessians);
} else {
GetGradientsKernel_RegressionL1<true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, cuda_weights_, num_data_, gradients, hessians);
}
}
template <bool USE_WEIGHT>
__global__ void RenewTreeOutputCUDAKernel_RegressionL1(
const double* score,
const label_t* label,
const label_t* weight,
double* residual_buffer,
label_t* weight_by_leaf,
double* weight_prefix_sum_buffer,
const data_size_t* data_indices_in_leaf,
const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf,
data_size_t* data_indices_buffer,
double* leaf_value) {
const int leaf_index = static_cast<int>(blockIdx.x);
const data_size_t data_start = data_start_in_leaf[leaf_index];
const data_size_t num_data = num_data_in_leaf[leaf_index];
data_size_t* data_indices_buffer_pointer = data_indices_buffer + data_start;
const label_t* weight_by_leaf_pointer = weight_by_leaf + data_start;
double* weight_prefix_sum_buffer_pointer = weight_prefix_sum_buffer + data_start;
const double* residual_buffer_pointer = residual_buffer + data_start;
const double alpha = 0.5f;
for (data_size_t inner_data_index = data_start + static_cast<data_size_t>(threadIdx.x); inner_data_index < data_start + num_data; inner_data_index += static_cast<data_size_t>(blockDim.x)) {
const data_size_t data_index = data_indices_in_leaf[inner_data_index];
const label_t data_label = label[data_index];
const double data_score = score[data_index];
residual_buffer[inner_data_index] = static_cast<double>(data_label) - data_score;
if (USE_WEIGHT) {
weight_by_leaf[inner_data_index] = weight[data_index];
}
}
__syncthreads();
const double renew_leaf_value = PercentileDevice<double, data_size_t, label_t, double, false, USE_WEIGHT>(
residual_buffer_pointer, weight_by_leaf_pointer, data_indices_buffer_pointer,
weight_prefix_sum_buffer_pointer, alpha, num_data);
if (threadIdx.x == 0) {
leaf_value[leaf_index] = renew_leaf_value;
}
}
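// One thread block handles one leaf: residuals (label - score) are gathered
// into this leaf's slice of residual_buffer, and the leaf output is renewed
// to their (weighted) median via PercentileDevice with alpha = 0.5.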
void CUDARegressionL1loss::LaunchRenewTreeOutputCUDAKernel(
const double* score,
const data_size_t* data_indices_in_leaf,
const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf,
const int num_leaves,
double* leaf_value) const {
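// The reduced block sizes appear to match the shared-memory capacity of the
// in-block bitonic sort inside PercentileDeviceInner: the unweighted path
// sorts up to BITONIC_SORT_NUM_ELEMENTS / 2 elements per block, while the
// weighted path, which additionally scans the weight prefix sum, handles
// BITONIC_SORT_NUM_ELEMENTS / 4.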
if (cuda_weights_ == nullptr) {
RenewTreeOutputCUDAKernel_RegressionL1<false><<<num_leaves, GET_GRADIENTS_BLOCK_SIZE_REGRESSION / 2>>>(
score,
cuda_labels_,
cuda_weights_,
cuda_residual_buffer_.RawData(),
cuda_weight_by_leaf_buffer_.RawData(),
cuda_weights_prefix_sum_.RawData(),
data_indices_in_leaf,
num_data_in_leaf,
data_start_in_leaf,
cuda_data_indices_buffer_.RawData(),
leaf_value);
} else {
RenewTreeOutputCUDAKernel_RegressionL1<true><<<num_leaves, GET_GRADIENTS_BLOCK_SIZE_REGRESSION / 4>>>(
score,
cuda_labels_,
cuda_weights_,
cuda_residual_buffer_.RawData(),
cuda_weight_by_leaf_buffer_.RawData(),
cuda_weights_prefix_sum_.RawData(),
data_indices_in_leaf,
num_data_in_leaf,
data_start_in_leaf,
cuda_data_indices_buffer_.RawData(),
leaf_value);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
@@ -36,9 +36,6 @@ class CUDARegressionL2loss : public CUDAObjectiveInterface, public RegressionL2l
double BoostFromScore(int) const override;
void RenewTreeOutputCUDA(const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const override;
std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
return [this] (data_size_t num_data, const double* input, double* output) {
ConvertOutputCUDA(num_data, input, output);
@@ -62,19 +59,48 @@ class CUDARegressionL2loss : public CUDAObjectiveInterface, public RegressionL2l
virtual void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const;
virtual void LaunchRenewTreeOutputCUDAKernel(
const double* /*score*/, const data_size_t* /*data_indices_in_leaf*/, const data_size_t* /*num_data_in_leaf*/,
const data_size_t* /*data_start_in_leaf*/, const int /*num_leaves*/, double* /*leaf_value*/) const {}
const label_t* cuda_labels_;
const label_t* cuda_weights_;
label_t* cuda_trans_label_;
double* cuda_block_buffer_;
CUDAVector<label_t> cuda_trans_label_;
CUDAVector<double> cuda_block_buffer_;
data_size_t num_get_gradients_blocks_;
data_size_t num_init_score_blocks_;
};
class CUDARegressionL1loss : public CUDARegressionL2loss {
public:
explicit CUDARegressionL1loss(const Config& config);
explicit CUDARegressionL1loss(const std::vector<std::string>& strs);
~CUDARegressionL1loss();
void Init(const Metadata& metadata, data_size_t num_data) override;
void RenewTreeOutputCUDA(const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const override;
bool IsRenewTreeOutput() const override { return true; }
protected:
CUDAVector<data_size_t> cuda_data_indices_buffer_;
CUDAVector<double> cuda_weights_prefix_sum_;
CUDAVector<double> cuda_weights_prefix_sum_buffer_;
CUDAVector<double> cuda_residual_buffer_;
CUDAVector<label_t> cuda_weight_by_leaf_buffer_;
CUDAVector<label_t> cuda_percentile_result_;
double LaunchCalcInitScoreKernel() const override;
void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const override;
void LaunchRenewTreeOutputCUDAKernel(
const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const;
};
} // namespace LightGBM
#endif // USE_CUDA_EXP
......