Unverified Commit 6b56a90c authored by shiyu1994, committed by GitHub

[CUDA] New CUDA version Part 1 (#4630)



* new cuda framework

* add histogram construction kernel

* before removing multi-gpu

* new cuda framework

* tree learner cuda kernels

* single tree framework ready

* single tree training framework

* remove comments

* boosting with cuda

* optimize for best split find

* data split

* move boosting into cuda

* parallel synchronize best split point

* merge split data kernels

* before code refactor

* use tasks instead of features as units for split finding

* refactor cuda best split finder

* fix configuration error with small leaves in data split

* skip histogram construction of too small leaf

* skip split finding of invalid leaves

stop when no leaf to split

* support row wise with CUDA

* copy data for split by column

* copy data from host to GPU by column for data partition

* add synchronize best splits for one leaf from multiple blocks

* partition dense row data

* fix sync best split from task blocks

* add support for sparse row wise for CUDA

* remove useless code

* add l2 regression objective

* sparse multi value bin enabled for CUDA

* fix cuda ranking objective

* support for number of items <= 2048 per query

* speedup histogram construction by interleaving global memory access

* split optimization

* add cuda tree predictor

* remove comma

* refactor objective and score updater

* before use struct

* use structure for split information

* use structure for leaf splits

* return CUDASplitInfo directly after finding best split

* split with CUDATree directly

* use cuda row data in cuda histogram constructor

* clean src/treelearner/cuda

* gather shared cuda device functions

* put shared CUDA functions into header file

* change smaller leaf from <= back to < for consistent result with CPU

* add tree predictor

* remove useless cuda_tree_predictor

* predict on CUDA with pipeline

* add global sort algorithms

* add global argsort for queries with many items in ranking tasks

* remove limitation of maximum number of items per query in ranking

* add cuda metrics

* fix CUDA AUC

* remove debug code

* add regression metrics

* remove useless file

* don't use mask in shuffle reduce

* add more regression objectives

* fix cuda mape loss

add cuda xentropy loss

* use template for different versions of BitonicArgSortDevice

* add multiclass metrics

* add ndcg metric

* fix cross entropy objectives and metrics

* fix cross entropy and ndcg metrics

* add support for customized objective in CUDA

* complete multiclass ova for CUDA

* separate cuda tree learner

* use shuffle based prefix sum (see the sketch just before the diff below)

* clean up cuda_algorithms.hpp

* add copy subset on CUDA

* add bagging for CUDA

* clean up code

* copy gradients from host to device

* support bagging without using subset

* add support of bagging with subset for CUDAColumnData

* add support of bagging with subset for dense CUDARowData

* refactor copy sparse subrow

* use copy subset for column subset

* add reset train data and reset config for CUDA tree learner

add deconstructors for cuda tree learner

* add USE_CUDA ifdef to cuda tree learner files

* check that dataset doesn't contain CUDA tree learner

* remove printf debug information

* use full new cuda tree learner only when using single GPU

* disable all CUDA code when using CPU version

* recover main.cpp

* add cpp files for multi value bins

* update LightGBM.vcxproj

* update LightGBM.vcxproj

fix lint errors

* fix lint errors

* fix lint errors

* update Makevars

fix lint errors

* fix the case with 0 feature and 0 bin

fix split finding for invalid leaves

create cuda column data when loaded from bin file

* fix lint errors

hide GetRowWiseData when cuda is not used

* recover default device type to cpu

* fix na_as_missing case

fix cuda feature meta information

* fix UpdateDataIndexToLeafIndexKernel

* create CUDA trees when needed in CUDADataPartition::UpdateTrainScore

* add refit by tree for cuda tree learner

* fix test_refit in test_engine.py

* create set of large bin partitions in CUDARowData

* add histogram construction for columns with a large number of bins

* add find best split for categorical features on CUDA

* add bitvectors for categorical split

* cuda data partition split for categorical features

* fix split tree with categorical feature

* fix categorical feature splits

* refactor cuda_data_partition.cu with multi-level templates

* refactor CUDABestSplitFinder by grouping task information into struct

* pre-allocate space for vector split_find_tasks_ in CUDABestSplitFinder

* fix misuse of reference

* remove useless changes

* add support for path smoothing

* virtual destructor for LightGBM::Tree

* fix overlapped cat threshold in best split infos

* reset histogram pointers in data partition and split finder in ResetConfig

* comment useless parameter

* fix reverse case when na is missing and default bin is zero

* fix mfb_is_na and mfb_is_zero and is_single_feature_column

* remove debug log

* fix cat_l2 when one-hot

fix gradient copy when data subset is used

* switch shared histogram size according to CUDA version

* gpu_use_dp=true when cuda test

* revert modification in config.h

* fix setting of gpu_use_dp=true in .ci/test.sh

* fix linter errors

* fix linter error

remove useless change

* recover main.cpp

* separate cuda_exp and cuda

* fix ci bash scripts

add description for cuda_exp

* add USE_CUDA_EXP flag

* switch off USE_CUDA_EXP

* revert changes in python-packages

* more careful separation for USE_CUDA_EXP

* fix CUDARowData::DivideCUDAFeatureGroups

fix set fields for cuda metadata

* revert config.h

* fix test settings for cuda experimental version

* skip some tests due to unsupported features or differences in implementation details for CUDA Experimental version

* fix lint issue by adding a blank line

* fix lint errors by resorting imports

* fix lint errors by resorting imports

* fix lint errors by resorting imports

* merge cuda.yml and cuda_exp.yml

* update python version in cuda.yml

* remove cuda_exp.yml

* remove unrelated changes

* fix compilation warnings

fix cuda exp ci task name

* recover task

* use multi-level template in histogram construction

check split only in debug mode

* ignore NVCC related lines in parameter_generator.py

* update job name for CUDA tests

* apply review suggestions

* Update .github/workflows/cuda.yml
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update .github/workflows/cuda.yml
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update header

* remove useless TODOs

* remove [TODO(shiyu1994): constrain the split with min_data_in_group] and record in #5062

* #include <LightGBM/utils/log.h> for USE_CUDA_EXP only

* fix include order

* fix include order

* remove extra space

* address review comments

* add warning when cuda_exp is used together with deterministic

* add comment about gpu_use_dp in .ci/test.sh

* revert changing order of included headers
Co-authored-by: Yu Shi <shiyu1994@qq.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent b857ee10
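One of the items above, "use shuffle based prefix sum", refers to computing histogram prefix sums with warp shuffle intrinsics instead of shared-memory scans. The device function below is a minimal illustrative sketch of that idea only, not the ShufflePrefixSum used by this PR: an inclusive prefix sum within a single warp via __shfl_up_sync.

// Minimal sketch (not this PR's implementation) of a shuffle-based prefix sum:
// inclusive prefix sum within one warp, no shared-memory round trip needed.
__device__ double WarpInclusivePrefixSum(double value) {
  const unsigned int lane = threadIdx.x & 31u;  // lane index within the warp
  for (int offset = 1; offset < 32; offset <<= 1) {
    const double other = __shfl_up_sync(0xffffffffu, value, offset);
    if (lane >= static_cast<unsigned int>(offset)) {
      value += other;  // accumulate contributions from lower lanes
    }
  }
  return value;  // lane i now holds the sum of lanes 0..i
}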
......@@ -272,6 +272,16 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
is_load_from_binary = true;
Log::Info("Load from binary file %s", bin_filename.c_str());
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
dataset->device_type_ = config_.device_type;
dataset->gpu_device_id_ = config_.gpu_device_id;
#ifdef USE_CUDA_EXP
if (config_.device_type == std::string("cuda_exp")) {
dataset->CreateCUDAColumnData();
dataset->metadata_.CreateCUDAMetadata(dataset->gpu_device_id_);
} else {
dataset->cuda_column_data_ = nullptr;
}
#endif // USE_CUDA_EXP
}
// check meta data
dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
......
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#include "dense_bin.hpp"
namespace LightGBM {
template <>
const void* DenseBin<uint8_t, false>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int /*num_threads*/) const {
*is_sparse = false;
*bit_type = 8;
bin_iterator->clear();
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint16_t, false>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int /*num_threads*/) const {
*is_sparse = false;
*bit_type = 16;
bin_iterator->clear();
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint32_t, false>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int /*num_threads*/) const {
*is_sparse = false;
*bit_type = 32;
bin_iterator->clear();
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint8_t, true>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int /*num_threads*/) const {
*is_sparse = false;
*bit_type = 4;
bin_iterator->clear();
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint8_t, false>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = false;
*bit_type = 8;
*bin_iterator = nullptr;
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint16_t, false>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = false;
*bit_type = 16;
*bin_iterator = nullptr;
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint32_t, false>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = false;
*bit_type = 32;
*bin_iterator = nullptr;
return reinterpret_cast<const void*>(data_.data());
}
template <>
const void* DenseBin<uint8_t, true>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = false;
*bit_type = 4;
*bin_iterator = nullptr;
return reinterpret_cast<const void*>(data_.data());
}
} // namespace LightGBM
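A hypothetical consumer sketch for the specializations above: the caller receives a raw pointer to the bin array and its element width through bit_type, so it can copy the dense column with the right stride. The helper name and byte-wise copy below are illustrative only, not part of this PR.

// Hypothetical consumer: copy a dense column using the reported bit_type;
// 4-bit columns pack two bins per byte.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

void CopyDenseColumnBytes(const void* col_data, uint8_t bit_type,
                          std::size_t num_data, std::vector<uint8_t>* out) {
  const std::size_t num_bytes =
      (bit_type == 4) ? (num_data + 1) / 2 : num_data * (bit_type / 8);
  out->resize(num_bytes);
  std::memcpy(out->data(), col_data, num_bytes);
}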
......@@ -461,9 +461,13 @@ class DenseBin : public Bin {
DenseBin<VAL_T, IS_4BIT>* Clone() override;
const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, std::vector<BinIterator*>* bin_iterator, const int num_threads) const override;
const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, BinIterator** bin_iterator) const override;
private:
data_size_t num_data_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_CUDA_EXP)
std::vector<VAL_T, CHAllocator<VAL_T>> data_;
#else
std::vector<VAL_T, Common::AlignmentAllocator<VAL_T, kAlignedSize>> data_;
......
......@@ -18,6 +18,9 @@ Metadata::Metadata() {
weight_load_from_file_ = false;
query_load_from_file_ = false;
init_score_load_from_file_ = false;
#ifdef USE_CUDA_EXP
cuda_metadata_ = nullptr;
#endif // USE_CUDA_EXP
}
void Metadata::Init(const char* data_filename) {
......@@ -302,6 +305,11 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
init_score_[i] = Common::AvoidInf(init_score[i]);
}
init_score_load_from_file_ = false;
#ifdef USE_CUDA_EXP
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetInitScore(init_score_.data(), len);
}
#endif // USE_CUDA_EXP
}
void Metadata::SetLabel(const label_t* label, data_size_t len) {
......@@ -318,6 +326,11 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) {
for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = Common::AvoidInf(label[i]);
}
#ifdef USE_CUDA_EXP
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetLabel(label_.data(), len);
}
#endif // USE_CUDA_EXP
}
void Metadata::SetWeights(const label_t* weights, data_size_t len) {
......@@ -340,6 +353,11 @@ void Metadata::SetWeights(const label_t* weights, data_size_t len) {
}
LoadQueryWeights();
weight_load_from_file_ = false;
#ifdef USE_CUDA_EXP
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetWeights(weights_.data(), len);
}
#endif // USE_CUDA_EXP
}
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
......@@ -366,6 +384,16 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
}
LoadQueryWeights();
query_load_from_file_ = false;
#ifdef USE_CUDA_EXP
if (cuda_metadata_ != nullptr) {
if (query_weights_.size() > 0) {
CHECK_EQ(query_weights_.size(), static_cast<size_t>(num_queries_));
cuda_metadata_->SetQuery(query_boundaries_.data(), query_weights_.data(), num_queries_);
} else {
cuda_metadata_->SetQuery(query_boundaries_.data(), nullptr, num_queries_);
}
}
#endif // USE_CUDA_EXP
}
void Metadata::LoadWeights() {
......@@ -472,6 +500,13 @@ void Metadata::LoadQueryWeights() {
}
}
#ifdef USE_CUDA_EXP
void Metadata::CreateCUDAMetadata(const int gpu_device_id) {
cuda_metadata_.reset(new CUDAMetadata(gpu_device_id));
cuda_metadata_->Init(label_, weights_, query_boundaries_, query_weights_, init_score_);
}
#endif // USE_CUDA_EXP
void Metadata::LoadFromMemory(const void* memory) {
const char* mem_ptr = reinterpret_cast<const char*>(memory);
......
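A minimal sketch of the host/device mirroring pattern used in the metadata setters above; the types and names here are illustrative, not the actual CUDAMetadata API. Every setter always updates the host copy, and pushes the same data to the device copy only when one was created for device_type=cuda_exp.

// Illustrative mirroring pattern (hypothetical names).
#include <memory>
#include <vector>

struct DeviceLabels {  // stand-in for a device-side mirror
  void Set(const float* /*data*/, int /*len*/) { /* e.g. copy to a device buffer */ }
};

class MetadataSketch {
 public:
  void SetLabel(const std::vector<float>& label) {
    label_ = label;                   // host copy is always updated
    if (device_labels_ != nullptr) {  // mirror only when the device copy exists
      device_labels_->Set(label_.data(), static_cast<int>(label_.size()));
    }
  }

 private:
  std::vector<float> label_;
  std::unique_ptr<DeviceLabels> device_labels_;  // created only in the cuda_exp path
};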
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include "multi_val_dense_bin.hpp"
namespace LightGBM {
#ifdef USE_CUDA_EXP
template <>
const void* MultiValDenseBin<uint8_t>::GetRowWiseData(uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = data_.data();
*bit_type = 8;
*total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_feature_);
CHECK_EQ(*total_size, data_.size());
*is_sparse = false;
*out_data_ptr = nullptr;
*data_ptr_bit_type = 0;
return to_return;
}
template <>
const void* MultiValDenseBin<uint16_t>::GetRowWiseData(uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint16_t* data_ptr = data_.data();
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_ptr);
*bit_type = 16;
*total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_feature_);
CHECK_EQ(*total_size, data_.size());
*is_sparse = false;
*out_data_ptr = nullptr;
*data_ptr_bit_type = 0;
return to_return;
}
template <>
const void* MultiValDenseBin<uint32_t>::GetRowWiseData(uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint32_t* data_ptr = data_.data();
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_ptr);
*bit_type = 32;
*total_size = static_cast<size_t>(num_data_) * static_cast<size_t>(num_feature_);
CHECK_EQ(*total_size, data_.size());
*is_sparse = false;
*out_data_ptr = nullptr;
*data_ptr_bit_type = 0;
return to_return;
}
#endif // USE_CUDA_EXP
} // namespace LightGBM
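Illustrative only: the dense multi-value bin stores one bin per (row, sub-feature) pair, which is why total_size above equals num_data * num_feature and no row pointer is returned. The accessor below assumes a row-major layout (one row's sub-feature bins stored consecutively) and is shown for the 8-bit case.

// Hypothetical addressing helper for the dense, row-major multi-value layout.
#include <cstddef>
#include <cstdint>

inline uint32_t DenseMultiValBinAt(const uint8_t* values, int num_feature,
                                   int row, int sub_feature) {
  return values[static_cast<std::size_t>(row) * num_feature + sub_feature];
}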
......@@ -7,6 +7,7 @@
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/threading.h>
#include <algorithm>
#include <cstdint>
......@@ -210,6 +211,14 @@ class MultiValDenseBin : public MultiValBin {
MultiValDenseBin<VAL_T>* Clone() override;
#ifdef USE_CUDA_EXP
const void* GetRowWiseData(uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const override;
#endif // USE_CUDA_EXP
private:
data_size_t num_data_;
int num_bin_;
......@@ -229,4 +238,5 @@ MultiValDenseBin<VAL_T>* MultiValDenseBin<VAL_T>::Clone() {
}
} // namespace LightGBM
#endif // LIGHTGBM_IO_MULTI_VAL_DENSE_BIN_HPP_
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include "multi_val_sparse_bin.hpp"
namespace LightGBM {
#ifdef USE_CUDA_EXP
template <>
const void* MultiValSparseBin<uint16_t, uint8_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = data_.data();
*bit_type = 8;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 16;
return to_return;
}
template <>
const void* MultiValSparseBin<uint16_t, uint16_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_.data());
*bit_type = 16;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 16;
return to_return;
}
template <>
const void* MultiValSparseBin<uint16_t, uint32_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_.data());
*bit_type = 32;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 16;
return to_return;
}
template <>
const void* MultiValSparseBin<uint32_t, uint8_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = data_.data();
*bit_type = 8;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 32;
return to_return;
}
template <>
const void* MultiValSparseBin<uint32_t, uint16_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_.data());
*bit_type = 16;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 32;
return to_return;
}
template <>
const void* MultiValSparseBin<uint32_t, uint32_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_.data());
*bit_type = 32;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 32;
return to_return;
}
template <>
const void* MultiValSparseBin<uint64_t, uint8_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = data_.data();
*bit_type = 8;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 64;
return to_return;
}
template <>
const void* MultiValSparseBin<uint64_t, uint16_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_.data());
*bit_type = 16;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 64;
return to_return;
}
template <>
const void* MultiValSparseBin<uint64_t, uint32_t>::GetRowWiseData(
uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const {
const uint8_t* to_return = reinterpret_cast<const uint8_t*>(data_.data());
*bit_type = 32;
*total_size = data_.size();
*is_sparse = true;
*out_data_ptr = reinterpret_cast<const uint8_t*>(row_ptr_.data());
*data_ptr_bit_type = 64;
return to_return;
}
#endif // USE_CUDA_EXP
} // namespace LightGBM
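In the sparse case above, GetRowWiseData additionally returns row_ptr_ through out_data_ptr and reports its index width in data_ptr_bit_type (16/32/64), while bit_type reports the width of the bin values. A hypothetical consumer would dispatch on those widths before touching the data; the CSR-style helper below is a sketch with illustrative names.

// Hypothetical dispatch on the reported row-pointer width.
#include <cstddef>
#include <cstdint>

template <typename INDEX_T>
std::size_t RowSize(const void* row_ptr, int row) {
  const INDEX_T* ptr = static_cast<const INDEX_T*>(row_ptr);
  return static_cast<std::size_t>(ptr[row + 1] - ptr[row]);  // stored bins in this row
}

std::size_t DispatchRowSize(const void* row_ptr, uint8_t data_ptr_bit_type, int row) {
  switch (data_ptr_bit_type) {
    case 16: return RowSize<uint16_t>(row_ptr, row);
    case 32: return RowSize<uint32_t>(row_ptr, row);
    case 64: return RowSize<uint64_t>(row_ptr, row);
    default: return 0;  // unexpected width
  }
}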
......@@ -7,6 +7,7 @@
#include <LightGBM/bin.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/threading.h>
#include <algorithm>
#include <cstdint>
......@@ -290,6 +291,15 @@ class MultiValSparseBin : public MultiValBin {
MultiValSparseBin<INDEX_T, VAL_T>* Clone() override;
#ifdef USE_CUDA_EXP
const void* GetRowWiseData(uint8_t* bit_type,
size_t* total_size,
bool* is_sparse,
const void** out_data_ptr,
uint8_t* data_ptr_bit_type) const override;
#endif // USE_CUDA_EXP
private:
data_size_t num_data_;
int num_bin_;
......@@ -317,4 +327,5 @@ MultiValSparseBin<INDEX_T, VAL_T>* MultiValSparseBin<INDEX_T, VAL_T>::Clone() {
}
} // namespace LightGBM
#endif // LIGHTGBM_IO_MULTI_VAL_SPARSE_BIN_HPP_
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#include "sparse_bin.hpp"
namespace LightGBM {
template <>
const void* SparseBin<uint8_t>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int num_threads) const {
*is_sparse = true;
*bit_type = 8;
for (int thread_index = 0; thread_index < num_threads; ++thread_index) {
bin_iterator->emplace_back(new SparseBinIterator<uint8_t>(this, 0));
}
return nullptr;
}
template <>
const void* SparseBin<uint16_t>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int num_threads) const {
*is_sparse = true;
*bit_type = 16;
for (int thread_index = 0; thread_index < num_threads; ++thread_index) {
bin_iterator->emplace_back(new SparseBinIterator<uint16_t>(this, 0));
}
return nullptr;
}
template <>
const void* SparseBin<uint32_t>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
std::vector<BinIterator*>* bin_iterator,
const int num_threads) const {
*is_sparse = true;
*bit_type = 32;
for (int thread_index = 0; thread_index < num_threads; ++thread_index) {
bin_iterator->emplace_back(new SparseBinIterator<uint32_t>(this, 0));
}
return nullptr;
}
template <>
const void* SparseBin<uint8_t>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = true;
*bit_type = 8;
*bin_iterator = new SparseBinIterator<uint8_t>(this, 0);
return nullptr;
}
template <>
const void* SparseBin<uint16_t>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = true;
*bit_type = 16;
*bin_iterator = new SparseBinIterator<uint16_t>(this, 0);
return nullptr;
}
template <>
const void* SparseBin<uint32_t>::GetColWiseData(
uint8_t* bit_type,
bool* is_sparse,
BinIterator** bin_iterator) const {
*is_sparse = true;
*bit_type = 32;
*bin_iterator = new SparseBinIterator<uint32_t>(this, 0);
return nullptr;
}
} // namespace LightGBM
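Illustrative consumption of the iterator path for sparse columns: the data pointer returned above is nullptr, so a caller streams bin values through the BinIterator it receives and owns. RawGet is part of LightGBM's existing BinIterator interface; the loop and ownership handling here are a sketch only.

// Hypothetical caller for the sparse, iterator-based column path.
#include <LightGBM/bin.h>
#include <cstdint>
#include <vector>

void StreamSparseColumn(LightGBM::BinIterator* it, int num_data,
                        std::vector<uint32_t>* out) {
  out->resize(static_cast<size_t>(num_data));
  for (int i = 0; i < num_data; ++i) {
    (*out)[i] = it->RawGet(i);  // raw bin value of row i
  }
  delete it;  // the iterator handed out above is owned by the caller
}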
......@@ -620,6 +620,10 @@ class SparseBin : public Bin {
}
}
const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, std::vector<BinIterator*>* bin_iterator, const int num_threads) const override;
const void* GetColWiseData(uint8_t* bit_type, bool* is_sparse, BinIterator** bin_iterator) const override;
private:
data_size_t num_data_;
std::vector<uint8_t, Common::AlignmentAllocator<uint8_t, kAlignedSize>>
......@@ -665,4 +669,5 @@ BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin,
}
} // namespace LightGBM
#endif // LightGBM_IO_SPARSE_BIN_HPP_
......@@ -382,6 +382,9 @@ void TrainingShareStates::CalcBinOffsets(const std::vector<std::unique_ptr<Featu
}
num_hist_total_bin_ = static_cast<int>(feature_hist_offsets_.back());
}
#ifdef USE_CUDA_EXP
column_hist_offsets_ = *offsets;
#endif // USE_CUDA_EXP
}
void TrainingShareStates::SetMultiValBin(MultiValBin* bin, data_size_t num_data,
......
......@@ -53,6 +53,9 @@ Tree::Tree(int max_leaves, bool track_branch_features, bool is_linear)
leaf_features_.resize(max_leaves_);
leaf_features_inner_.resize(max_leaves_);
}
#ifdef USE_CUDA_EXP
is_cuda_tree_ = false;
#endif // USE_CUDA_EXP
}
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
......@@ -734,6 +737,10 @@ Tree::Tree(const char* str, size_t* used_len) {
is_linear_ = false;
}
#ifdef USE_CUDA_EXP
is_cuda_tree_ = false;
#endif // USE_CUDA_EXP
if ((num_leaves_ <= 1) && !is_linear_) {
return;
}
......
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include <algorithm>
#include "cuda_best_split_finder.hpp"
#include "cuda_leaf_splits.hpp"
namespace LightGBM {
CUDABestSplitFinder::CUDABestSplitFinder(
const hist_t* cuda_hist,
const Dataset* train_data,
const std::vector<uint32_t>& feature_hist_offsets,
const Config* config):
num_features_(train_data->num_features()),
num_leaves_(config->num_leaves),
feature_hist_offsets_(feature_hist_offsets),
lambda_l1_(config->lambda_l1),
lambda_l2_(config->lambda_l2),
min_data_in_leaf_(config->min_data_in_leaf),
min_sum_hessian_in_leaf_(config->min_sum_hessian_in_leaf),
min_gain_to_split_(config->min_gain_to_split),
cat_smooth_(config->cat_smooth),
cat_l2_(config->cat_l2),
max_cat_threshold_(config->max_cat_threshold),
min_data_per_group_(config->min_data_per_group),
max_cat_to_onehot_(config->max_cat_to_onehot),
extra_trees_(config->extra_trees),
extra_seed_(config->extra_seed),
use_smoothing_(config->path_smooth > 0),
path_smooth_(config->path_smooth),
num_total_bin_(feature_hist_offsets.empty() ? 0 : static_cast<int>(feature_hist_offsets.back())),
cuda_hist_(cuda_hist) {
InitFeatureMetaInfo(train_data);
cuda_leaf_best_split_info_ = nullptr;
cuda_best_split_info_ = nullptr;
cuda_best_split_info_buffer_ = nullptr;
cuda_is_feature_used_bytree_ = nullptr;
}
CUDABestSplitFinder::~CUDABestSplitFinder() {
DeallocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_, __FILE__, __LINE__);
DeallocateCUDAMemory<CUDASplitInfo>(&cuda_best_split_info_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_best_split_info_buffer_, __FILE__, __LINE__);
cuda_split_find_tasks_.Clear();
DeallocateCUDAMemory<int8_t>(&cuda_is_feature_used_bytree_, __FILE__, __LINE__);
gpuAssert(cudaStreamDestroy(cuda_streams_[0]), __FILE__, __LINE__);
gpuAssert(cudaStreamDestroy(cuda_streams_[1]), __FILE__, __LINE__);
cuda_streams_.clear();
cuda_streams_.shrink_to_fit();
}
void CUDABestSplitFinder::InitFeatureMetaInfo(const Dataset* train_data) {
feature_missing_type_.resize(num_features_);
feature_mfb_offsets_.resize(num_features_);
feature_default_bins_.resize(num_features_);
feature_num_bins_.resize(num_features_);
max_num_bin_in_feature_ = 0;
has_categorical_feature_ = false;
max_num_categorical_bin_ = 0;
is_categorical_.resize(train_data->num_features(), 0);
for (int inner_feature_index = 0; inner_feature_index < num_features_; ++inner_feature_index) {
const BinMapper* bin_mapper = train_data->FeatureBinMapper(inner_feature_index);
if (bin_mapper->bin_type() == BinType::CategoricalBin) {
has_categorical_feature_ = true;
is_categorical_[inner_feature_index] = 1;
if (bin_mapper->num_bin() > max_num_categorical_bin_) {
max_num_categorical_bin_ = bin_mapper->num_bin();
}
}
const MissingType missing_type = bin_mapper->missing_type();
feature_missing_type_[inner_feature_index] = missing_type;
feature_mfb_offsets_[inner_feature_index] = static_cast<int8_t>(bin_mapper->GetMostFreqBin() == 0);
feature_default_bins_[inner_feature_index] = bin_mapper->GetDefaultBin();
feature_num_bins_[inner_feature_index] = static_cast<uint32_t>(bin_mapper->num_bin());
const int num_bin_hist = bin_mapper->num_bin() - feature_mfb_offsets_[inner_feature_index];
if (num_bin_hist > max_num_bin_in_feature_) {
max_num_bin_in_feature_ = num_bin_hist;
}
}
if (max_num_bin_in_feature_ > NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER) {
use_global_memory_ = true;
} else {
use_global_memory_ = false;
}
}
void CUDABestSplitFinder::Init() {
InitCUDAFeatureMetaInfo();
cuda_streams_.resize(2);
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_streams_[0]));
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_streams_[1]));
AllocateCUDAMemory<int>(&cuda_best_split_info_buffer_, 8, __FILE__, __LINE__);
if (use_global_memory_) {
AllocateCUDAMemory<hist_t>(&cuda_feature_hist_grad_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
AllocateCUDAMemory<hist_t>(&cuda_feature_hist_hess_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
if (has_categorical_feature_) {
AllocateCUDAMemory<hist_t>(&cuda_feature_hist_stat_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_feature_hist_index_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
}
}
}
void CUDABestSplitFinder::InitCUDAFeatureMetaInfo() {
AllocateCUDAMemory<int8_t>(&cuda_is_feature_used_bytree_, static_cast<size_t>(num_features_), __FILE__, __LINE__);
// initialize split find task information (a split find task is one pass through the histogram of a feature)
num_tasks_ = 0;
for (int inner_feature_index = 0; inner_feature_index < num_features_; ++inner_feature_index) {
const uint32_t num_bin = feature_num_bins_[inner_feature_index];
const MissingType missing_type = feature_missing_type_[inner_feature_index];
if (num_bin > 2 && missing_type != MissingType::None && !is_categorical_[inner_feature_index]) {
num_tasks_ += 2;
} else {
++num_tasks_;
}
}
split_find_tasks_.resize(num_tasks_);
split_find_tasks_.shrink_to_fit();
int cur_task_index = 0;
for (int inner_feature_index = 0; inner_feature_index < num_features_; ++inner_feature_index) {
const uint32_t num_bin = feature_num_bins_[inner_feature_index];
const MissingType missing_type = feature_missing_type_[inner_feature_index];
if (num_bin > 2 && missing_type != MissingType::None && !is_categorical_[inner_feature_index]) {
if (missing_type == MissingType::Zero) {
SplitFindTask* new_task = &split_find_tasks_[cur_task_index];
new_task->reverse = false;
new_task->skip_default_bin = true;
new_task->na_as_missing = false;
new_task->inner_feature_index = inner_feature_index;
new_task->assume_out_default_left = false;
new_task->is_categorical = false;
uint32_t num_bin = feature_num_bins_[inner_feature_index];
new_task->is_one_hot = false;
new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
new_task->default_bin = feature_default_bins_[inner_feature_index];
new_task->num_bin = num_bin;
++cur_task_index;
new_task = &split_find_tasks_[cur_task_index];
new_task->reverse = true;
new_task->skip_default_bin = true;
new_task->na_as_missing = false;
new_task->inner_feature_index = inner_feature_index;
new_task->assume_out_default_left = true;
new_task->is_categorical = false;
num_bin = feature_num_bins_[inner_feature_index];
new_task->is_one_hot = false;
new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
new_task->default_bin = feature_default_bins_[inner_feature_index];
new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
new_task->num_bin = num_bin;
++cur_task_index;
} else {
SplitFindTask* new_task = &split_find_tasks_[cur_task_index];
new_task->reverse = false;
new_task->skip_default_bin = false;
new_task->na_as_missing = true;
new_task->inner_feature_index = inner_feature_index;
new_task->assume_out_default_left = false;
new_task->is_categorical = false;
uint32_t num_bin = feature_num_bins_[inner_feature_index];
new_task->is_one_hot = false;
new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
new_task->default_bin = feature_default_bins_[inner_feature_index];
new_task->num_bin = num_bin;
++cur_task_index;
new_task = &split_find_tasks_[cur_task_index];
new_task->reverse = true;
new_task->skip_default_bin = false;
new_task->na_as_missing = true;
new_task->inner_feature_index = inner_feature_index;
new_task->assume_out_default_left = true;
new_task->is_categorical = false;
num_bin = feature_num_bins_[inner_feature_index];
new_task->is_one_hot = false;
new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
new_task->default_bin = feature_default_bins_[inner_feature_index];
new_task->num_bin = num_bin;
++cur_task_index;
}
} else {
SplitFindTask& new_task = split_find_tasks_[cur_task_index];
const uint32_t num_bin = feature_num_bins_[inner_feature_index];
if (is_categorical_[inner_feature_index]) {
new_task.reverse = false;
new_task.is_categorical = true;
new_task.is_one_hot = (static_cast<int>(num_bin) <= max_cat_to_onehot_);
} else {
new_task.reverse = true;
new_task.is_categorical = false;
new_task.is_one_hot = false;
}
new_task.skip_default_bin = false;
new_task.na_as_missing = false;
new_task.inner_feature_index = inner_feature_index;
if (missing_type != MissingType::NaN && !is_categorical_[inner_feature_index]) {
new_task.assume_out_default_left = true;
} else {
new_task.assume_out_default_left = false;
}
new_task.hist_offset = feature_hist_offsets_[inner_feature_index];
new_task.mfb_offset = feature_mfb_offsets_[inner_feature_index];
new_task.default_bin = feature_default_bins_[inner_feature_index];
new_task.num_bin = num_bin;
++cur_task_index;
}
}
CHECK_EQ(cur_task_index, static_cast<int>(split_find_tasks_.size()));
if (extra_trees_) {
cuda_randoms_.Resize(num_tasks_ * 2);
LaunchInitCUDARandomKernel();
}
const int num_task_blocks = (num_tasks_ + NUM_TASKS_PER_SYNC_BLOCK - 1) / NUM_TASKS_PER_SYNC_BLOCK;
const size_t cuda_best_leaf_split_info_buffer_size = static_cast<size_t>(num_task_blocks) * static_cast<size_t>(num_leaves_);
AllocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_,
cuda_best_leaf_split_info_buffer_size,
__FILE__,
__LINE__);
cuda_split_find_tasks_.Resize(num_tasks_);
CopyFromHostToCUDADevice<SplitFindTask>(cuda_split_find_tasks_.RawData(),
split_find_tasks_.data(),
split_find_tasks_.size(),
__FILE__,
__LINE__);
const size_t output_buffer_size = 2 * static_cast<size_t>(num_tasks_);
AllocateCUDAMemory<CUDASplitInfo>(&cuda_best_split_info_, output_buffer_size, __FILE__, __LINE__);
max_num_categories_in_split_ = std::min(max_cat_threshold_, max_num_categorical_bin_ / 2);
AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, max_num_categories_in_split_ * output_buffer_size, __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, max_num_categories_in_split_ * output_buffer_size, __FILE__, __LINE__);
AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, max_num_categories_in_split_ * cuda_best_leaf_split_info_buffer_size, __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, max_num_categories_in_split_ * cuda_best_leaf_split_info_buffer_size, __FILE__, __LINE__);
AllocateCatVectors(cuda_leaf_best_split_info_, cuda_cat_threshold_leaf_, cuda_cat_threshold_real_leaf_, cuda_best_leaf_split_info_buffer_size);
AllocateCatVectors(cuda_best_split_info_, cuda_cat_threshold_feature_, cuda_cat_threshold_real_feature_, output_buffer_size);
}
void CUDABestSplitFinder::ResetTrainingData(
const hist_t* cuda_hist,
const Dataset* train_data,
const std::vector<uint32_t>& feature_hist_offsets) {
cuda_hist_ = cuda_hist;
num_features_ = train_data->num_features();
feature_hist_offsets_ = feature_hist_offsets;
InitFeatureMetaInfo(train_data);
DeallocateCUDAMemory<int8_t>(&cuda_is_feature_used_bytree_, __FILE__, __LINE__);
DeallocateCUDAMemory<CUDASplitInfo>(&cuda_best_split_info_, __FILE__, __LINE__);
InitCUDAFeatureMetaInfo();
}
void CUDABestSplitFinder::ResetConfig(const Config* config, const hist_t* cuda_hist) {
num_leaves_ = config->num_leaves;
lambda_l1_ = config->lambda_l1;
lambda_l2_ = config->lambda_l2;
min_data_in_leaf_ = config->min_data_in_leaf;
min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf;
min_gain_to_split_ = config->min_gain_to_split;
cat_smooth_ = config->cat_smooth;
cat_l2_ = config->cat_l2;
max_cat_threshold_ = config->max_cat_threshold;
min_data_per_group_ = config->min_data_per_group;
max_cat_to_onehot_ = config->max_cat_to_onehot;
extra_trees_ = config->extra_trees;
extra_seed_ = config->extra_seed;
use_smoothing_ = (config->path_smooth > 0.0f);
path_smooth_ = config->path_smooth;
cuda_hist_ = cuda_hist;
const int num_task_blocks = (num_tasks_ + NUM_TASKS_PER_SYNC_BLOCK - 1) / NUM_TASKS_PER_SYNC_BLOCK;
size_t cuda_best_leaf_split_info_buffer_size = static_cast<size_t>(num_task_blocks) * static_cast<size_t>(num_leaves_);
DeallocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_, __FILE__, __LINE__);
AllocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_,
cuda_best_leaf_split_info_buffer_size,
__FILE__,
__LINE__);
max_num_categories_in_split_ = std::min(max_cat_threshold_, max_num_categorical_bin_ / 2);
size_t total_cat_threshold_size = max_num_categories_in_split_ * cuda_best_leaf_split_info_buffer_size;
DeallocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, __FILE__, __LINE__);
AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, total_cat_threshold_size, __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, total_cat_threshold_size, __FILE__, __LINE__);
AllocateCatVectors(cuda_leaf_best_split_info_, cuda_cat_threshold_leaf_, cuda_cat_threshold_real_leaf_, cuda_best_leaf_split_info_buffer_size);
cuda_best_leaf_split_info_buffer_size = 2 * static_cast<size_t>(num_tasks_);
total_cat_threshold_size = max_num_categories_in_split_ * cuda_best_leaf_split_info_buffer_size;
DeallocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, __FILE__, __LINE__);
AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, total_cat_threshold_size, __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, total_cat_threshold_size, __FILE__, __LINE__);
AllocateCatVectors(cuda_best_split_info_, cuda_cat_threshold_feature_, cuda_cat_threshold_real_feature_, cuda_best_leaf_split_info_buffer_size);
}
void CUDABestSplitFinder::BeforeTrain(const std::vector<int8_t>& is_feature_used_bytree) {
CopyFromHostToCUDADevice<int8_t>(cuda_is_feature_used_bytree_,
is_feature_used_bytree.data(),
is_feature_used_bytree.size(), __FILE__, __LINE__);
}
void CUDABestSplitFinder::FindBestSplitsForLeaf(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const CUDALeafSplitsStruct* larger_leaf_splits,
const int smaller_leaf_index,
const int larger_leaf_index,
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf) {
const bool is_smaller_leaf_valid = (num_data_in_smaller_leaf > min_data_in_leaf_ &&
sum_hessians_in_smaller_leaf > min_sum_hessian_in_leaf_);
const bool is_larger_leaf_valid = (num_data_in_larger_leaf > min_data_in_leaf_ &&
sum_hessians_in_larger_leaf > min_sum_hessian_in_leaf_ && larger_leaf_index >= 0);
LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
global_timer.Start("CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel");
LaunchSyncBestSplitForLeafKernel(smaller_leaf_index, larger_leaf_index, is_smaller_leaf_valid, is_larger_leaf_valid);
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel");
}
const CUDASplitInfo* CUDABestSplitFinder::FindBestFromAllSplits(
const int cur_num_leaves,
const int smaller_leaf_index,
const int larger_leaf_index,
int* smaller_leaf_best_split_feature,
uint32_t* smaller_leaf_best_split_threshold,
uint8_t* smaller_leaf_best_split_default_left,
int* larger_leaf_best_split_feature,
uint32_t* larger_leaf_best_split_threshold,
uint8_t* larger_leaf_best_split_default_left,
int* best_leaf_index,
int* num_cat_threshold) {
LaunchFindBestFromAllSplitsKernel(
cur_num_leaves,
smaller_leaf_index,
larger_leaf_index,
smaller_leaf_best_split_feature,
smaller_leaf_best_split_threshold,
smaller_leaf_best_split_default_left,
larger_leaf_best_split_feature,
larger_leaf_best_split_threshold,
larger_leaf_best_split_default_left,
best_leaf_index,
num_cat_threshold);
SynchronizeCUDADevice(__FILE__, __LINE__);
return cuda_leaf_best_split_info_ + (*best_leaf_index);
}
void CUDABestSplitFinder::AllocateCatVectors(CUDASplitInfo* cuda_split_infos, uint32_t* cat_threshold_vec, int* cat_threshold_real_vec, size_t len) {
LaunchAllocateCatVectorsKernel(cuda_split_infos, cat_threshold_vec, cat_threshold_real_vec, len);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
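For reading the split kernels in the .cu file that follows, here is a simplified host-side reference of the gain they evaluate per candidate threshold. It ignores L1 regularization, path smoothing, and the extra-trees random threshold that the template parameters switch on, and the function names are illustrative only.

// Simplified reference of the per-threshold gain computed in the kernels below
// (L1 regularization and path smoothing omitted; names are illustrative).
inline double LeafGain(double sum_gradient, double sum_hessian, double lambda_l2) {
  // gain of a leaf whose output is the optimum -G / (H + lambda_l2)
  return (sum_gradient * sum_gradient) / (sum_hessian + lambda_l2);
}

inline double SplitGainOverParent(double sum_left_gradient, double sum_left_hessian,
                                  double sum_right_gradient, double sum_right_hessian,
                                  double lambda_l2, double parent_gain,
                                  double min_gain_to_split) {
  const double gain_with_split =
      LeafGain(sum_left_gradient, sum_left_hessian, lambda_l2) +
      LeafGain(sum_right_gradient, sum_right_hessian, lambda_l2);
  // mirrors the min_gain_shift check: a split is kept only when it beats the
  // parent gain by at least min_gain_to_split
  return gain_with_split - (parent_gain + min_gain_to_split);
}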
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include <algorithm>
#include <LightGBM/cuda/cuda_algorithms.hpp>
#include "cuda_best_split_finder.hpp"
namespace LightGBM {
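// The three reductions below select, first within each warp and then across the
// block, the thread holding the best (found, gain) pair: ReduceBestGainWarp folds
// lanes with __shfl_down_sync and writes the winner to lane 0's output slots,
// ReduceBestGainBlock repeats the fold over the per-warp winners, and
// ReduceBestGain combines the two through small shared-memory buffers so that the
// block ends up with the thread index of its best split candidate.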
__device__ void ReduceBestGainWarp(double gain, bool found, uint32_t thread_index, double* out_gain, bool* out_found, uint32_t* out_thread_index) {
const uint32_t mask = 0xffffffff;
const uint32_t warpLane = threadIdx.x % warpSize;
for (uint32_t offset = warpSize / 2; offset > 0; offset >>= 1) {
const bool other_found = __shfl_down_sync(mask, found, offset);
const double other_gain = __shfl_down_sync(mask, gain, offset);
const uint32_t other_thread_index = __shfl_down_sync(mask, thread_index, offset);
if ((other_found && found && other_gain > gain) || (!found && other_found)) {
found = other_found;
gain = other_gain;
thread_index = other_thread_index;
}
}
if (warpLane == 0) {
*out_gain = gain;
*out_found = found;
*out_thread_index = thread_index;
}
}
__device__ uint32_t ReduceBestGainBlock(double gain, bool found, uint32_t thread_index) {
const uint32_t mask = 0xffffffff;
for (uint32_t offset = warpSize / 2; offset > 0; offset >>= 1) {
const bool other_found = __shfl_down_sync(mask, found, offset);
const double other_gain = __shfl_down_sync(mask, gain, offset);
const uint32_t other_thread_index = __shfl_down_sync(mask, thread_index, offset);
if ((other_found && found && other_gain > gain) || (!found && other_found)) {
found = other_found;
gain = other_gain;
thread_index = other_thread_index;
}
}
return thread_index;
}
__device__ uint32_t ReduceBestGain(double gain, bool found, uint32_t thread_index,
double* shared_gain_buffer, bool* shared_found_buffer, uint32_t* shared_thread_index_buffer) {
const uint32_t warpID = threadIdx.x / warpSize;
const uint32_t warpLane = threadIdx.x % warpSize;
const uint32_t num_warp = blockDim.x / warpSize;
ReduceBestGainWarp(gain, found, thread_index, shared_gain_buffer + warpID, shared_found_buffer + warpID, shared_thread_index_buffer + warpID);
__syncthreads();
if (warpID == 0) {
gain = warpLane < num_warp ? shared_gain_buffer[warpLane] : kMinScore;
found = warpLane < num_warp ? shared_found_buffer[warpLane] : false;
thread_index = warpLane < num_warp ? shared_thread_index_buffer[warpLane] : 0;
thread_index = ReduceBestGainBlock(gain, found, thread_index);
}
return thread_index;
}
__device__ void ReduceBestGainForLeaves(double* gain, int* leaves, int cuda_cur_num_leaves) {
const unsigned int tid = threadIdx.x;
for (unsigned int s = 1; s < cuda_cur_num_leaves; s *= 2) {
if (tid % (2 * s) == 0 && (tid + s) < cuda_cur_num_leaves) {
const uint32_t tid_s = tid + s;
if ((leaves[tid] == -1 && leaves[tid_s] != -1) || (leaves[tid] != -1 && leaves[tid_s] != -1 && gain[tid_s] > gain[tid])) {
gain[tid] = gain[tid_s];
leaves[tid] = leaves[tid_s];
}
}
__syncthreads();
}
}
__device__ void ReduceBestGainForLeavesWarp(double gain, int leaf_index, double* out_gain, int* out_leaf_index) {
const uint32_t mask = 0xffffffff;
const uint32_t warpLane = threadIdx.x % warpSize;
for (uint32_t offset = warpSize / 2; offset > 0; offset >>= 1) {
const int other_leaf_index = __shfl_down_sync(mask, leaf_index, offset);
const double other_gain = __shfl_down_sync(mask, gain, offset);
if ((leaf_index != -1 && other_leaf_index != -1 && other_gain > gain) || (leaf_index == -1 && other_leaf_index != -1)) {
gain = other_gain;
leaf_index = other_leaf_index;
}
}
if (warpLane == 0) {
*out_gain = gain;
*out_leaf_index = leaf_index;
}
}
__device__ int ReduceBestGainForLeavesBlock(double gain, int leaf_index) {
const uint32_t mask = 0xffffffff;
for (uint32_t offset = warpSize / 2; offset > 0; offset >>= 1) {
const int other_leaf_index = __shfl_down_sync(mask, leaf_index, offset);
const double other_gain = __shfl_down_sync(mask, gain, offset);
if ((leaf_index != -1 && other_leaf_index != -1 && other_gain > gain) || (leaf_index == -1 && other_leaf_index != -1)) {
gain = other_gain;
leaf_index = other_leaf_index;
}
}
return leaf_index;
}
__device__ int ReduceBestGainForLeaves(double gain, int leaf_index, double* shared_gain_buffer, int* shared_leaf_index_buffer) {
const uint32_t warpID = threadIdx.x / warpSize;
const uint32_t warpLane = threadIdx.x % warpSize;
const uint32_t num_warp = blockDim.x / warpSize;
ReduceBestGainForLeavesWarp(gain, leaf_index, shared_gain_buffer + warpID, shared_leaf_index_buffer + warpID);
__syncthreads();
if (warpID == 0) {
gain = warpLane < num_warp ? shared_gain_buffer[warpLane] : kMinScore;
leaf_index = warpLane < num_warp ? shared_leaf_index_buffer[warpLane] : -1;
leaf_index = ReduceBestGainForLeavesBlock(gain, leaf_index);
}
return leaf_index;
}
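// Per-task numerical split search over one feature histogram: each thread owns
// one bin, the block builds prefix sums of the gradient/hessian histogram with
// ShufflePrefixSum, every thread then scores the split at its bin against
// min_gain_shift, and ReduceBestGain picks the block-wide winner, which fills
// cuda_best_split_info. REVERSE scans the histogram from the high-bin end,
// matching the CPU learner's reverse pass for missing-value handling.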
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool REVERSE>
__device__ void FindBestSplitsForLeafKernelInner(
// input feature information
const hist_t* feature_hist_ptr,
// input task information
const SplitFindTask* task,
CUDARandom* cuda_random,
// input config parameter values
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
// input parent node information
const double parent_gain,
const double sum_gradients,
const double sum_hessians,
const data_size_t num_data,
const double parent_output,
// output parameters
CUDASplitInfo* cuda_best_split_info) {
const double cnt_factor = num_data / sum_hessians;
const double min_gain_shift = parent_gain + min_gain_to_split;
cuda_best_split_info->is_valid = false;
hist_t local_grad_hist = 0.0f;
hist_t local_hess_hist = 0.0f;
double local_gain = 0.0f;
bool threshold_found = false;
uint32_t threshold_value = 0;
__shared__ int rand_threshold;
if (USE_RAND && threadIdx.x == 0) {
if (task->num_bin - 2 > 0) {
rand_threshold = cuda_random->NextInt(0, task->num_bin - 2);
}
}
__shared__ uint32_t best_thread_index;
__shared__ double shared_double_buffer[32];
__shared__ bool shared_bool_buffer[32];
__shared__ uint32_t shared_int_buffer[32];
const unsigned int threadIdx_x = threadIdx.x;
const bool skip_sum = REVERSE ?
(task->skip_default_bin && (task->num_bin - 1 - threadIdx_x) == static_cast<int>(task->default_bin)) :
(task->skip_default_bin && (threadIdx_x + task->mfb_offset) == static_cast<int>(task->default_bin));
const uint32_t feature_num_bin_minus_offset = task->num_bin - task->mfb_offset;
if (!REVERSE) {
if (task->na_as_missing && task->mfb_offset == 1) {
if (threadIdx_x < static_cast<uint32_t>(task->num_bin) && threadIdx_x > 0) {
const unsigned int bin_offset = (threadIdx_x - 1) << 1;
local_grad_hist = feature_hist_ptr[bin_offset];
local_hess_hist = feature_hist_ptr[bin_offset + 1];
}
} else {
if (threadIdx_x < feature_num_bin_minus_offset && !skip_sum) {
const unsigned int bin_offset = threadIdx_x << 1;
local_grad_hist = feature_hist_ptr[bin_offset];
local_hess_hist = feature_hist_ptr[bin_offset + 1];
}
}
} else {
if (threadIdx_x >= static_cast<unsigned int>(task->na_as_missing) &&
threadIdx_x < feature_num_bin_minus_offset && !skip_sum) {
const unsigned int read_index = feature_num_bin_minus_offset - 1 - threadIdx_x;
const unsigned int bin_offset = read_index << 1;
local_grad_hist = feature_hist_ptr[bin_offset];
local_hess_hist = feature_hist_ptr[bin_offset + 1];
}
}
__syncthreads();
if (!REVERSE && task->na_as_missing && task->mfb_offset == 1) {
const hist_t sum_gradients_non_default = ShuffleReduceSum<hist_t>(local_grad_hist, shared_double_buffer, blockDim.x);
__syncthreads();
const hist_t sum_hessians_non_default = ShuffleReduceSum<hist_t>(local_hess_hist, shared_double_buffer, blockDim.x);
if (threadIdx_x == 0) {
local_grad_hist += (sum_gradients - sum_gradients_non_default);
local_hess_hist += (sum_hessians - sum_hessians_non_default);
}
}
if (threadIdx_x == 0) {
local_hess_hist += kEpsilon;
}
local_gain = kMinScore;
local_grad_hist = ShufflePrefixSum(local_grad_hist, shared_double_buffer);
__syncthreads();
local_hess_hist = ShufflePrefixSum(local_hess_hist, shared_double_buffer);
if (REVERSE) {
if (threadIdx_x >= static_cast<unsigned int>(task->na_as_missing) && threadIdx_x <= task->num_bin - 2 && !skip_sum) {
const double sum_right_gradient = local_grad_hist;
const double sum_right_hessian = local_hess_hist;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double sum_left_gradient = sum_gradients - sum_right_gradient;
const double sum_left_hessian = sum_hessians - sum_right_hessian;
const data_size_t left_count = num_data - right_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(task->num_bin - 2 - threadIdx_x) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// keep the split only when its gain exceeds the gain without splitting (min_gain_shift)
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = static_cast<uint32_t>(task->num_bin - 2 - threadIdx_x);
threshold_found = true;
}
}
}
} else {
const uint32_t end = (task->na_as_missing && task->mfb_offset == 1) ? static_cast<uint32_t>(task->num_bin - 2) : feature_num_bin_minus_offset - 2;
if (threadIdx_x <= end && !skip_sum) {
const double sum_left_gradient = local_grad_hist;
const double sum_left_hessian = local_hess_hist;
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(threadIdx_x + task->mfb_offset) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// keep the split only when its gain exceeds the gain without splitting (min_gain_shift)
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = (task->na_as_missing && task->mfb_offset == 1) ?
static_cast<uint32_t>(threadIdx_x) :
static_cast<uint32_t>(threadIdx_x + task->mfb_offset);
threshold_found = true;
}
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_double_buffer, shared_bool_buffer, shared_int_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->threshold = threshold_value;
cuda_best_split_info->gain = local_gain;
cuda_best_split_info->default_left = task->assume_out_default_left;
if (REVERSE) {
const double sum_right_gradient = local_grad_hist;
const double sum_right_hessian = local_hess_hist - kEpsilon;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double sum_left_gradient = sum_gradients - sum_right_gradient;
const double sum_left_hessian = sum_hessians - sum_right_hessian - kEpsilon;
const data_size_t left_count = num_data - right_count;
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, right_output);
} else {
const double sum_left_gradient = local_grad_hist;
const double sum_left_hessian = local_hess_hist - kEpsilon;
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian - kEpsilon;
const data_size_t right_count = num_data - left_count;
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, right_output);
}
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
__device__ void FindBestSplitsForLeafKernelCategoricalInner(
// input feature information
const hist_t* feature_hist_ptr,
// input task information
const SplitFindTask* task,
CUDARandom* cuda_random,
// input config parameter values
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
const double cat_smooth,
const double cat_l2,
const int max_cat_threshold,
const int min_data_per_group,
// input parent node information
const double parent_gain,
const double sum_gradients,
const double sum_hessians,
const data_size_t num_data,
const double parent_output,
// output parameters
CUDASplitInfo* cuda_best_split_info) {
__shared__ double shared_gain_buffer[32];
__shared__ bool shared_found_buffer[32];
__shared__ uint32_t shared_thread_index_buffer[32];
__shared__ uint32_t best_thread_index;
const double cnt_factor = num_data / sum_hessians;
const double min_gain_shift = parent_gain + min_gain_to_split;
double l2 = lambda_l2;
double local_gain = min_gain_shift;
bool threshold_found = false;
cuda_best_split_info->is_valid = false;
const int bin_start = 1 - task->mfb_offset;
const int bin_end = task->num_bin - task->mfb_offset;
const int threadIdx_x = static_cast<int>(threadIdx.x);
__shared__ int rand_threshold;
if (task->is_one_hot) {
if (USE_RAND && threadIdx.x == 0) {
rand_threshold = 0;
if (bin_end > bin_start) {
rand_threshold = cuda_random->NextInt(bin_start, bin_end);
}
}
__syncthreads();
if (threadIdx_x >= bin_start && threadIdx_x < bin_end) {
const int bin_offset = (threadIdx_x << 1);
const hist_t grad = feature_hist_ptr[bin_offset];
const hist_t hess = feature_hist_ptr[bin_offset + 1];
data_size_t cnt =
static_cast<data_size_t>(__double2int_rn(hess * cnt_factor));
if (cnt >= min_data_in_leaf && hess >= min_sum_hessian_in_leaf) {
const data_size_t other_count = num_data - cnt;
if (other_count >= min_data_in_leaf) {
const double sum_other_hessian = sum_hessians - hess - kEpsilon;
if (sum_other_hessian >= min_sum_hessian_in_leaf && (!USE_RAND || static_cast<int>(threadIdx_x) == rand_threshold)) {
const double sum_other_gradient = sum_gradients - grad;
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_other_gradient, sum_other_hessian, grad,
hess + kEpsilon, lambda_l1,
l2, path_smooth, other_count, cnt, parent_output);
if (current_gain > min_gain_shift) {
local_gain = current_gain;
threshold_found = true;
}
}
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_gain_buffer, shared_found_buffer, shared_thread_index_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->num_cat_threshold = 1;
cuda_best_split_info->gain = local_gain - min_gain_shift;
*(cuda_best_split_info->cat_threshold) = static_cast<uint32_t>(threadIdx_x + task->mfb_offset);
cuda_best_split_info->default_left = false;
const int bin_offset = (threadIdx_x << 1);
const hist_t sum_left_gradient = feature_hist_ptr[bin_offset];
const hist_t sum_left_hessian = feature_hist_ptr[bin_offset + 1];
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, right_output);
}
} else {
__shared__ double shared_value_buffer[NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER];
__shared__ int16_t shared_index_buffer[NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER];
__shared__ uint16_t shared_mem_buffer_uint16[32];
__shared__ double shared_mem_buffer_double[32];
__shared__ int used_bin;
l2 += cat_l2;
uint16_t is_valid_bin = 0;
int best_dir = 0;
double best_sum_left_gradient = 0.0f;
double best_sum_left_hessian = 0.0f;
if (threadIdx_x >= bin_start && threadIdx_x < bin_end) {
const int bin_offset = (threadIdx_x << 1);
const double hess = feature_hist_ptr[bin_offset + 1];
if (__double2int_rn(hess * cnt_factor) >= cat_smooth) {
const double grad = feature_hist_ptr[bin_offset];
shared_value_buffer[threadIdx_x] = grad / (hess + cat_smooth);
is_valid_bin = 1;
} else {
shared_value_buffer[threadIdx_x] = kMaxScore;
}
} else {
shared_value_buffer[threadIdx_x] = kMaxScore;
}
shared_index_buffer[threadIdx_x] = threadIdx_x;
__syncthreads();
const int local_used_bin = ShuffleReduceSum<uint16_t>(is_valid_bin, shared_mem_buffer_uint16, blockDim.x);
if (threadIdx_x == 0) {
used_bin = local_used_bin;
}
__syncthreads();
BitonicArgSort_1024<double, int16_t, true>(shared_value_buffer, shared_index_buffer, bin_end);
__syncthreads();
const int max_num_cat = min(max_cat_threshold, (used_bin + 1) / 2);
if (USE_RAND) {
rand_threshold = 0;
const int max_threshold = max(min(max_num_cat, used_bin) - 1, 0);
if (max_threshold > 0) {
rand_threshold = cuda_random->NextInt(0, max_threshold);
}
}
// left to right
double grad = 0.0f;
double hess = 0.0f;
if (threadIdx_x < used_bin && threadIdx_x < max_num_cat) {
const int bin_offset = (shared_index_buffer[threadIdx_x] << 1);
grad = feature_hist_ptr[bin_offset];
hess = feature_hist_ptr[bin_offset + 1];
}
if (threadIdx_x == 0) {
hess += kEpsilon;
}
__syncthreads();
double sum_left_gradient = ShufflePrefixSum<double>(grad, shared_mem_buffer_double);
__syncthreads();
double sum_left_hessian = ShufflePrefixSum<double>(hess, shared_mem_buffer_double);
if (threadIdx_x < used_bin && threadIdx_x < max_num_cat) {
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || threadIdx_x == static_cast<int>(rand_threshold))) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
l2, path_smooth, left_count, right_count, parent_output);
// take this split only if its gain beats the current best (initialized to the no-split gain plus min_gain_to_split)
if (current_gain > local_gain) {
local_gain = current_gain;
threshold_found = true;
best_dir = 1;
best_sum_left_gradient = sum_left_gradient;
best_sum_left_hessian = sum_left_hessian;
}
}
}
__syncthreads();
// right to left
grad = 0.0f;
hess = 0.0f;
if (threadIdx_x < used_bin && threadIdx_x < max_num_cat) {
const int bin_offset = (shared_index_buffer[used_bin - 1 - threadIdx_x] << 1);
grad = feature_hist_ptr[bin_offset];
hess = feature_hist_ptr[bin_offset + 1];
}
if (threadIdx_x == 0) {
hess += kEpsilon;
}
__syncthreads();
sum_left_gradient = ShufflePrefixSum<double>(grad, shared_mem_buffer_double);
__syncthreads();
sum_left_hessian = ShufflePrefixSum<double>(hess, shared_mem_buffer_double);
if (threadIdx_x < used_bin && threadIdx_x < max_num_cat) {
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || threadIdx_x == static_cast<int>(rand_threshold))) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
l2, path_smooth, left_count, right_count, parent_output);
// take this split only if its gain beats the current best (initialized to the no-split gain plus min_gain_to_split)
if (current_gain > local_gain) {
local_gain = current_gain;
threshold_found = true;
best_dir = -1;
best_sum_left_gradient = sum_left_gradient;
best_sum_left_hessian = sum_left_hessian;
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_gain_buffer, shared_found_buffer, shared_thread_index_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->num_cat_threshold = threadIdx_x + 1;
cuda_best_split_info->gain = local_gain - min_gain_shift;
if (best_dir == 1) {
for (int i = 0; i < threadIdx_x + 1; ++i) {
(cuda_best_split_info->cat_threshold)[i] = shared_index_buffer[i] + task->mfb_offset;
}
} else {
for (int i = 0; i < threadIdx_x + 1; ++i) {
(cuda_best_split_info->cat_threshold)[i] = shared_index_buffer[used_bin - 1 - i] + task->mfb_offset;
}
}
cuda_best_split_info->default_left = false;
const hist_t sum_left_gradient = best_sum_left_gradient;
const hist_t sum_left_hessian = best_sum_left_hessian;
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, right_output);
}
}
}
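// one block per split-finding task: reads the parent statistics of the smaller or larger leaf
// (chosen by IS_LARGER) and dispatches to the numerical (forward/reverse) or categorical search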
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool IS_LARGER>
__global__ void FindBestSplitsForLeafKernel(
// input feature information
const int8_t* is_feature_used_bytree,
// input task information
const int num_tasks,
const SplitFindTask* tasks,
CUDARandom* cuda_randoms,
// input leaf information
const CUDALeafSplitsStruct* smaller_leaf_splits,
const CUDALeafSplitsStruct* larger_leaf_splits,
// input config parameter values
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const double cat_smooth,
const double cat_l2,
const int max_cat_threshold,
const int min_data_per_group,
// output
CUDASplitInfo* cuda_best_split_info) {
const unsigned int task_index = blockIdx.x;
const SplitFindTask* task = tasks + task_index;
const int inner_feature_index = task->inner_feature_index;
const double parent_gain = IS_LARGER ? larger_leaf_splits->gain : smaller_leaf_splits->gain;
const double sum_gradients = IS_LARGER ? larger_leaf_splits->sum_of_gradients : smaller_leaf_splits->sum_of_gradients;
const double sum_hessians = (IS_LARGER ? larger_leaf_splits->sum_of_hessians : smaller_leaf_splits->sum_of_hessians) + 2 * kEpsilon;
const data_size_t num_data = IS_LARGER ? larger_leaf_splits->num_data_in_leaf : smaller_leaf_splits->num_data_in_leaf;
const double parent_output = IS_LARGER ? larger_leaf_splits->leaf_value : smaller_leaf_splits->leaf_value;
const unsigned int output_offset = IS_LARGER ? (task_index + num_tasks) : task_index;
CUDASplitInfo* out = cuda_best_split_info + output_offset;
CUDARandom* cuda_random = USE_RAND ?
(IS_LARGER ? cuda_randoms + task_index * 2 + 1 : cuda_randoms + task_index * 2) : nullptr;
if (is_feature_used_bytree[inner_feature_index]) {
const hist_t* hist_ptr = (IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + task->hist_offset * 2;
if (task->is_categorical) {
FindBestSplitsForLeafKernelCategoricalInner<USE_RAND, USE_L1, USE_SMOOTHING>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
cat_smooth,
cat_l2,
max_cat_threshold,
min_data_per_group,
// input parent node information
parent_gain,
sum_gradients,
sum_hessians,
num_data,
parent_output,
// output parameters
out);
} else {
if (!task->reverse) {
FindBestSplitsForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, false>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients,
sum_hessians,
num_data,
parent_output,
// output parameters
out);
} else {
FindBestSplitsForLeafKernelInner<USE_RAND, USE_L1, USE_SMOOTHING, true>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients,
sum_hessians,
num_data,
parent_output,
// output parameters
out);
}
}
} else {
out->is_valid = false;
}
}
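// global-memory variant of the numerical split search: per-bin gradients and hessians are
// staged in pre-allocated buffers and accumulated with GlobalMemoryPrefixSum instead of shared memory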
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool REVERSE>
__device__ void FindBestSplitsForLeafKernelInner_GlobalMemory(
// input feature information
const hist_t* feature_hist_ptr,
// input task information
const SplitFindTask* task,
CUDARandom* cuda_random,
// input config parameter values
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
// input parent node information
const double parent_gain,
const double sum_gradients,
const double sum_hessians,
const data_size_t num_data,
const double parent_output,
// output parameters
CUDASplitInfo* cuda_best_split_info,
// buffer
hist_t* hist_grad_buffer_ptr,
hist_t* hist_hess_buffer_ptr) {
const double cnt_factor = num_data / sum_hessians;
const double min_gain_shift = parent_gain + min_gain_to_split;
cuda_best_split_info->is_valid = false;
double local_gain = 0.0f;
bool threshold_found = false;
uint32_t threshold_value = 0;
__shared__ int rand_threshold;
if (USE_RAND && threadIdx.x == 0) {
if (task->num_bin - 2 > 0) {
rand_threshold = cuda_random->NextInt(0, task->num_bin - 2);
}
}
__shared__ uint32_t best_thread_index;
__shared__ double shared_double_buffer[32];
__shared__ bool shared_found_buffer[32];
__shared__ uint32_t shared_thread_index_buffer[32];
const unsigned int threadIdx_x = threadIdx.x;
const uint32_t feature_num_bin_minus_offset = task->num_bin - task->mfb_offset;
if (!REVERSE) {
if (task->na_as_missing && task->mfb_offset == 1) {
uint32_t bin_start = threadIdx_x > 0 ? threadIdx_x : blockDim.x;
hist_t thread_sum_gradients = 0.0f;
hist_t thread_sum_hessians = 0.0f;
for (unsigned int bin = bin_start; bin < static_cast<uint32_t>(task->num_bin); bin += blockDim.x) {
const unsigned int bin_offset = (bin - 1) << 1;
const hist_t grad = feature_hist_ptr[bin_offset];
const hist_t hess = feature_hist_ptr[bin_offset + 1];
hist_grad_buffer_ptr[bin] = grad;
hist_hess_buffer_ptr[bin] = hess;
thread_sum_gradients += grad;
thread_sum_hessians += hess;
}
const hist_t sum_gradients_non_default = ShuffleReduceSum<double>(thread_sum_gradients, shared_double_buffer, blockDim.x);
__syncthreads();
const hist_t sum_hessians_non_default = ShuffleReduceSum<double>(thread_sum_hessians, shared_double_buffer, blockDim.x);
if (threadIdx_x == 0) {
hist_grad_buffer_ptr[0] = sum_gradients - sum_gradients_non_default;
hist_hess_buffer_ptr[0] = sum_hessians - sum_hessians_non_default;
}
} else {
for (unsigned int bin = threadIdx_x; bin < feature_num_bin_minus_offset; bin += blockDim.x) {
const bool skip_sum =
(task->skip_default_bin && (bin + task->mfb_offset) == static_cast<int>(task->default_bin));
if (!skip_sum) {
const unsigned int bin_offset = bin << 1;
hist_grad_buffer_ptr[bin] = feature_hist_ptr[bin_offset];
hist_hess_buffer_ptr[bin] = feature_hist_ptr[bin_offset + 1];
} else {
hist_grad_buffer_ptr[bin] = 0.0f;
hist_hess_buffer_ptr[bin] = 0.0f;
}
}
}
} else {
for (unsigned int bin = threadIdx_x; bin < feature_num_bin_minus_offset; bin += blockDim.x) {
const bool skip_sum = bin >= static_cast<unsigned int>(task->na_as_missing) &&
(task->skip_default_bin && (task->num_bin - 1 - bin) == static_cast<int>(task->default_bin));
if (!skip_sum) {
const unsigned int read_index = feature_num_bin_minus_offset - 1 - bin;
const unsigned int bin_offset = read_index << 1;
hist_grad_buffer_ptr[bin] = feature_hist_ptr[bin_offset];
hist_hess_buffer_ptr[bin] = feature_hist_ptr[bin_offset + 1];
} else {
hist_grad_buffer_ptr[bin] = 0.0f;
hist_hess_buffer_ptr[bin] = 0.0f;
}
}
}
__syncthreads();
if (threadIdx_x == 0) {
hist_hess_buffer_ptr[0] += kEpsilon;
}
local_gain = kMinScore;
GlobalMemoryPrefixSum(hist_grad_buffer_ptr, static_cast<size_t>(feature_num_bin_minus_offset));
__syncthreads();
GlobalMemoryPrefixSum(hist_hess_buffer_ptr, static_cast<size_t>(feature_num_bin_minus_offset));
if (REVERSE) {
for (unsigned int bin = threadIdx_x; bin < feature_num_bin_minus_offset; bin += blockDim.x) {
const bool skip_sum = (bin >= static_cast<unsigned int>(task->na_as_missing) &&
(task->skip_default_bin && (task->num_bin - 1 - bin) == static_cast<int>(task->default_bin)));
if (!skip_sum) {
const double sum_right_gradient = hist_grad_buffer_ptr[bin];
const double sum_right_hessian = hist_hess_buffer_ptr[bin];
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double sum_left_gradient = sum_gradients - sum_right_gradient;
const double sum_left_hessian = sum_hessians - sum_right_hessian;
const data_size_t left_count = num_data - right_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(task->num_bin - 2 - bin) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// take this split only if its gain exceeds the no-split gain plus min_gain_to_split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = static_cast<uint32_t>(task->num_bin - 2 - bin);
threshold_found = true;
}
}
}
}
} else {
const uint32_t end = (task->na_as_missing && task->mfb_offset == 1) ? static_cast<uint32_t>(task->num_bin - 2) : feature_num_bin_minus_offset - 2;
for (unsigned int bin = threadIdx_x; bin <= end; bin += blockDim.x) {
const bool skip_sum =
(task->skip_default_bin && (bin + task->mfb_offset) == static_cast<int>(task->default_bin));
if (!skip_sum) {
const double sum_left_gradient = hist_grad_buffer_ptr[bin];
const double sum_left_hessian = hist_hess_buffer_ptr[bin];
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf &&
(!USE_RAND || static_cast<int>(bin + task->mfb_offset) == rand_threshold)) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
lambda_l2, path_smooth, left_count, right_count, parent_output);
// take this split only if its gain exceeds the no-split gain plus min_gain_to_split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_value = (task->na_as_missing && task->mfb_offset == 1) ?
bin : static_cast<uint32_t>(bin + task->mfb_offset);
threshold_found = true;
}
}
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_double_buffer, shared_found_buffer, shared_thread_index_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->threshold = threshold_value;
cuda_best_split_info->gain = local_gain;
cuda_best_split_info->default_left = task->assume_out_default_left;
if (REVERSE) {
const unsigned int best_bin = static_cast<uint32_t>(task->num_bin - 2 - threshold_value);
const double sum_right_gradient = hist_grad_buffer_ptr[best_bin];
const double sum_right_hessian = hist_hess_buffer_ptr[best_bin] - kEpsilon;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double sum_left_gradient = sum_gradients - sum_right_gradient;
const double sum_left_hessian = sum_hessians - sum_right_hessian - kEpsilon;
const data_size_t left_count = num_data - right_count;
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, right_output);
} else {
const unsigned int best_bin = (task->na_as_missing && task->mfb_offset == 1) ?
threshold_value : static_cast<uint32_t>(threshold_value - task->mfb_offset);
const double sum_left_gradient = hist_grad_buffer_ptr[best_bin];
const double sum_left_hessian = hist_hess_buffer_ptr[best_bin] - kEpsilon;
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian - kEpsilon;
const data_size_t right_count = num_data - left_count;
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, lambda_l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, lambda_l2, right_output);
}
}
}
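// global-memory variant of the categorical split search: bin statistics and sorted bin
// indices are kept in pre-allocated device buffers instead of shared memory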
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
__device__ void FindBestSplitsForLeafKernelCategoricalInner_GlobalMemory(
// input feature information
const hist_t* feature_hist_ptr,
// input task information
const SplitFindTask* task,
CUDARandom* cuda_random,
// input config parameter values
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
const double cat_smooth,
const double cat_l2,
const int max_cat_threshold,
const int min_data_per_group,
// input parent node information
const double parent_gain,
const double sum_gradients,
const double sum_hessians,
const data_size_t num_data,
const double parent_output,
// buffer
hist_t* hist_grad_buffer_ptr,
hist_t* hist_hess_buffer_ptr,
hist_t* hist_stat_buffer_ptr,
data_size_t* hist_index_buffer_ptr,
// output parameters
CUDASplitInfo* cuda_best_split_info) {
__shared__ double shared_gain_buffer[32];
__shared__ bool shared_found_buffer[32];
__shared__ uint32_t shared_thread_index_buffer[32];
__shared__ uint32_t best_thread_index;
const double cnt_factor = num_data / sum_hessians;
const double min_gain_shift = parent_gain + min_gain_to_split;
double l2 = lambda_l2;
double local_gain = kMinScore;
bool threshold_found = false;
cuda_best_split_info->is_valid = false;
__shared__ int rand_threshold;
const int bin_start = 1 - task->mfb_offset;
const int bin_end = task->num_bin - task->mfb_offset;
int best_threshold = -1;
const int threadIdx_x = static_cast<int>(threadIdx.x);
if (task->is_one_hot) {
if (USE_RAND && threadIdx.x == 0) {
rand_threshold = 0;
if (bin_end > bin_start) {
rand_threshold = cuda_random->NextInt(bin_start, bin_end);
}
}
__syncthreads();
for (int bin = bin_start + threadIdx_x; bin < bin_end; bin += static_cast<int>(blockDim.x)) {
const int bin_offset = (bin << 1);
const hist_t grad = feature_hist_ptr[bin_offset];
const hist_t hess = feature_hist_ptr[bin_offset + 1];
data_size_t cnt =
static_cast<data_size_t>(__double2int_rn(hess * cnt_factor));
if (cnt >= min_data_in_leaf && hess >= min_sum_hessian_in_leaf) {
const data_size_t other_count = num_data - cnt;
if (other_count >= min_data_in_leaf) {
const double sum_other_hessian = sum_hessians - hess - kEpsilon;
if (sum_other_hessian >= min_sum_hessian_in_leaf && (!USE_RAND || bin == rand_threshold)) {
const double sum_other_gradient = sum_gradients - grad;
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_other_gradient, sum_other_hessian, grad,
hess + kEpsilon, lambda_l1,
l2, path_smooth, other_count, cnt, parent_output);
if (current_gain > min_gain_shift) {
best_threshold = bin;
local_gain = current_gain - min_gain_shift;
threshold_found = true;
}
}
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_gain_buffer, shared_found_buffer, shared_thread_index_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->num_cat_threshold = 1;
// cat_threshold points into the buffer pre-allocated by AllocateCatVectorsKernel
cuda_best_split_info->gain = local_gain;
*(cuda_best_split_info->cat_threshold) = static_cast<uint32_t>(best_threshold + task->mfb_offset);
cuda_best_split_info->default_left = false;
const int bin_offset = (best_threshold << 1);
const hist_t sum_left_gradient = feature_hist_ptr[bin_offset];
const hist_t sum_left_hessian = feature_hist_ptr[bin_offset + 1];
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, right_output);
}
} else {
__shared__ uint16_t shared_mem_buffer_uint16[32];
__shared__ int used_bin;
l2 += cat_l2;
uint16_t is_valid_bin = 0;
int best_dir = 0;
double best_sum_left_gradient = 0.0f;
double best_sum_left_hessian = 0.0f;
for (int bin = static_cast<int>(threadIdx_x); bin < bin_end; bin += static_cast<int>(blockDim.x)) {
if (bin >= bin_start) {
const int bin_offset = (bin << 1);
const double hess = feature_hist_ptr[bin_offset + 1];
if (__double2int_rn(hess * cnt_factor) >= cat_smooth) {
const double grad = feature_hist_ptr[bin_offset];
hist_stat_buffer_ptr[bin] = grad / (hess + cat_smooth);
hist_index_buffer_ptr[bin] = bin;
// count the valid bins handled by this thread
is_valid_bin += 1;
} else {
hist_stat_buffer_ptr[bin] = kMaxScore;
hist_index_buffer_ptr[bin] = -1;
}
} else {
hist_stat_buffer_ptr[bin] = kMaxScore;
hist_index_buffer_ptr[bin] = -1;
}
}
__syncthreads();
const int local_used_bin = ShuffleReduceSum<uint16_t>(is_valid_bin, shared_mem_buffer_uint16, blockDim.x);
if (threadIdx_x == 0) {
used_bin = local_used_bin;
}
__syncthreads();
BitonicArgSortDevice<double, data_size_t, true, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 11>(
hist_stat_buffer_ptr, hist_index_buffer_ptr, task->num_bin - task->mfb_offset);
const int max_num_cat = min(max_cat_threshold, (used_bin + 1) / 2);
if (USE_RAND) {
rand_threshold = 0;
const int max_threshold = max(min(max_num_cat, used_bin) - 1, 0);
if (max_threshold > 0) {
rand_threshold = cuda_random->NextInt(0, max_threshold);
}
}
__syncthreads();
// left to right
for (int bin = static_cast<int>(threadIdx_x); bin < used_bin && bin < max_num_cat; bin += static_cast<int>(blockDim.x)) {
const int bin_offset = (hist_index_buffer_ptr[bin] << 1);
hist_grad_buffer_ptr[bin] = feature_hist_ptr[bin_offset];
hist_hess_buffer_ptr[bin] = feature_hist_ptr[bin_offset + 1];
}
if (threadIdx_x == 0) {
hist_hess_buffer_ptr[0] += kEpsilon;
}
__syncthreads();
GlobalMemoryPrefixSum<double>(hist_grad_buffer_ptr, static_cast<size_t>(bin_end));
__syncthreads();
GlobalMemoryPrefixSum<double>(hist_hess_buffer_ptr, static_cast<size_t>(bin_end));
for (int bin = static_cast<int>(threadIdx_x); bin < used_bin && bin < max_num_cat; bin += static_cast<int>(blockDim.x)) {
const double sum_left_gradient = hist_grad_buffer_ptr[bin];
const double sum_left_hessian = hist_hess_buffer_ptr[bin];
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
l2, path_smooth, left_count, right_count, parent_output);
// take this split only if its gain exceeds the no-split gain plus min_gain_to_split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_found = true;
best_dir = 1;
best_sum_left_gradient = sum_left_gradient;
best_sum_left_hessian = sum_left_hessian;
best_threshold = bin;
}
}
}
__syncthreads();
// right to left
for (int bin = static_cast<int>(threadIdx_x); bin < used_bin && bin < max_num_cat; bin += static_cast<int>(blockDim.x)) {
const int bin_offset = (hist_index_buffer_ptr[used_bin - 1 - bin] << 1);
hist_grad_buffer_ptr[bin] = feature_hist_ptr[bin_offset];
hist_hess_buffer_ptr[bin] = feature_hist_ptr[bin_offset + 1];
}
if (threadIdx_x == 0) {
hist_hess_buffer_ptr[0] += kEpsilon;
}
__syncthreads();
GlobalMemoryPrefixSum<double>(hist_grad_buffer_ptr, static_cast<size_t>(bin_end));
__syncthreads();
GlobalMemoryPrefixSum<double>(hist_hess_buffer_ptr, static_cast<size_t>(bin_end));
for (int bin = static_cast<int>(threadIdx_x); bin < used_bin && bin < max_num_cat; bin += static_cast<int>(blockDim.x)) {
const double sum_left_gradient = hist_grad_buffer_ptr[bin];
const double sum_left_hessian = hist_hess_buffer_ptr[bin];
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = num_data - left_count;
if (sum_left_hessian >= min_sum_hessian_in_leaf && left_count >= min_data_in_leaf &&
sum_right_hessian >= min_sum_hessian_in_leaf && right_count >= min_data_in_leaf) {
double current_gain = CUDALeafSplits::GetSplitGains<USE_L1, USE_SMOOTHING>(
sum_left_gradient, sum_left_hessian, sum_right_gradient,
sum_right_hessian, lambda_l1,
l2, path_smooth, left_count, right_count, parent_output);
// take this split only if its gain exceeds the no-split gain plus min_gain_to_split
if (current_gain > min_gain_shift) {
local_gain = current_gain - min_gain_shift;
threshold_found = true;
best_dir = -1;
best_sum_left_gradient = sum_left_gradient;
best_sum_left_hessian = sum_left_hessian;
best_threshold = bin;
}
}
}
__syncthreads();
const uint32_t result = ReduceBestGain(local_gain, threshold_found, threadIdx_x, shared_gain_buffer, shared_found_buffer, shared_thread_index_buffer);
if (threadIdx_x == 0) {
best_thread_index = result;
}
__syncthreads();
if (threshold_found && threadIdx_x == best_thread_index) {
cuda_best_split_info->is_valid = true;
cuda_best_split_info->num_cat_threshold = best_threshold + 1;
// cat_threshold points into the buffer pre-allocated by AllocateCatVectorsKernel
cuda_best_split_info->gain = local_gain;
if (best_dir == 1) {
for (int i = 0; i < best_threshold + 1; ++i) {
(cuda_best_split_info->cat_threshold)[i] = hist_index_buffer_ptr[i] + task->mfb_offset;
}
} else {
for (int i = 0; i < best_threshold + 1; ++i) {
(cuda_best_split_info->cat_threshold)[i] = hist_index_buffer_ptr[used_bin - 1 - i] + task->mfb_offset;
}
}
cuda_best_split_info->default_left = false;
const hist_t sum_left_gradient = best_sum_left_gradient;
const hist_t sum_left_hessian = best_sum_left_hessian;
const data_size_t left_count = static_cast<data_size_t>(__double2int_rn(sum_left_hessian * cnt_factor));
const double sum_right_gradient = sum_gradients - sum_left_gradient;
const double sum_right_hessian = sum_hessians - sum_left_hessian;
const data_size_t right_count = static_cast<data_size_t>(__double2int_rn(sum_right_hessian * cnt_factor));
const double left_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, path_smooth, left_count, parent_output);
const double right_output = CUDALeafSplits::CalculateSplittedLeafOutput<USE_L1, USE_SMOOTHING>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, path_smooth, right_count, parent_output);
cuda_best_split_info->left_sum_gradients = sum_left_gradient;
cuda_best_split_info->left_sum_hessians = sum_left_hessian;
cuda_best_split_info->left_count = left_count;
cuda_best_split_info->right_sum_gradients = sum_right_gradient;
cuda_best_split_info->right_sum_hessians = sum_right_hessian;
cuda_best_split_info->right_count = right_count;
cuda_best_split_info->left_value = left_output;
cuda_best_split_info->left_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_left_gradient,
sum_left_hessian, lambda_l1, l2, left_output);
cuda_best_split_info->right_value = right_output;
cuda_best_split_info->right_gain = CUDALeafSplits::GetLeafGainGivenOutput<USE_L1>(sum_right_gradient,
sum_right_hessian, lambda_l1, l2, right_output);
}
}
}
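// same task dispatch as FindBestSplitsForLeafKernel, but routes each task to the
// global-memory variants and passes per-task slices of the histogram work buffers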
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING, bool IS_LARGER>
__global__ void FindBestSplitsForLeafKernel_GlobalMemory(
// input feature information
const int8_t* is_feature_used_bytree,
// input task information
const int num_tasks,
const SplitFindTask* tasks,
CUDARandom* cuda_randoms,
// input leaf information
const CUDALeafSplitsStruct* smaller_leaf_splits,
const CUDALeafSplitsStruct* larger_leaf_splits,
// input config parameter values
const data_size_t min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const double min_gain_to_split,
const double lambda_l1,
const double lambda_l2,
const double path_smooth,
const double cat_smooth,
const double cat_l2,
const int max_cat_threshold,
const int min_data_per_group,
// output
CUDASplitInfo* cuda_best_split_info,
// buffer
hist_t* feature_hist_grad_buffer,
hist_t* feature_hist_hess_buffer,
hist_t* feature_hist_stat_buffer,
data_size_t* feature_hist_index_buffer) {
const unsigned int task_index = blockIdx.x;
const SplitFindTask* task = tasks + task_index;
const double parent_gain = IS_LARGER ? larger_leaf_splits->gain : smaller_leaf_splits->gain;
const double sum_gradients = IS_LARGER ? larger_leaf_splits->sum_of_gradients : smaller_leaf_splits->sum_of_gradients;
const double sum_hessians = (IS_LARGER ? larger_leaf_splits->sum_of_hessians : smaller_leaf_splits->sum_of_hessians) + 2 * kEpsilon;
const data_size_t num_data = IS_LARGER ? larger_leaf_splits->num_data_in_leaf : smaller_leaf_splits->num_data_in_leaf;
const double parent_output = IS_LARGER ? larger_leaf_splits->leaf_value : smaller_leaf_splits->leaf_value;
const unsigned int output_offset = IS_LARGER ? (task_index + num_tasks) : task_index;
CUDASplitInfo* out = cuda_best_split_info + output_offset;
CUDARandom* cuda_random = USE_RAND ?
(IS_LARGER ? cuda_randoms + task_index * 2 + 1 : cuda_randoms + task_index * 2) : nullptr;
if (is_feature_used_bytree[task->inner_feature_index]) {
const uint32_t hist_offset = task->hist_offset;
const hist_t* hist_ptr = (IS_LARGER ? larger_leaf_splits->hist_in_leaf : smaller_leaf_splits->hist_in_leaf) + hist_offset * 2;
hist_t* hist_grad_buffer_ptr = feature_hist_grad_buffer + hist_offset * 2;
hist_t* hist_hess_buffer_ptr = feature_hist_hess_buffer + hist_offset * 2;
hist_t* hist_stat_buffer_ptr = feature_hist_stat_buffer + hist_offset * 2;
data_size_t* hist_index_buffer_ptr = feature_hist_index_buffer + hist_offset * 2;
if (task->is_categorical) {
FindBestSplitsForLeafKernelCategoricalInner_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
cat_smooth,
cat_l2,
max_cat_threshold,
min_data_per_group,
// input parent node information
parent_gain,
sum_gradients,
sum_hessians,
num_data,
parent_output,
// buffer
hist_grad_buffer_ptr,
hist_hess_buffer_ptr,
hist_stat_buffer_ptr,
hist_index_buffer_ptr,
// output parameters
out);
} else {
if (!task->reverse) {
FindBestSplitsForLeafKernelInner_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING, false>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients,
sum_hessians,
num_data,
parent_output,
// output parameters
out,
// buffer
hist_grad_buffer_ptr,
hist_hess_buffer_ptr);
} else {
FindBestSplitsForLeafKernelInner_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING, true>(
// input feature information
hist_ptr,
// input task information
task,
cuda_random,
// input config parameter values
lambda_l1,
lambda_l2,
path_smooth,
min_data_in_leaf,
min_sum_hessian_in_leaf,
min_gain_to_split,
// input parent node information
parent_gain,
sum_gradients,
sum_hessians,
num_data,
parent_output,
// output parameters
out,
// buffer
hist_grad_buffer_ptr,
hist_hess_buffer_ptr);
}
}
} else {
out->is_valid = false;
}
}
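// the macros below factor out the long parameter and argument lists shared by the
// LaunchFindBestSplitsForLeafKernel* helpers and the split-finding kernels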
#define LaunchFindBestSplitsForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
const CUDALeafSplitsStruct* larger_leaf_splits, \
const int smaller_leaf_index, \
const int larger_leaf_index, \
const bool is_smaller_leaf_valid, \
const bool is_larger_leaf_valid
#define LaunchFindBestSplitsForLeafKernel_ARGS \
smaller_leaf_splits, \
larger_leaf_splits, \
smaller_leaf_index, \
larger_leaf_index, \
is_smaller_leaf_valid, \
is_larger_leaf_valid
#define FindBestSplitsForLeafKernel_ARGS \
cuda_is_feature_used_bytree_, \
num_tasks_, \
cuda_split_find_tasks_.RawData(), \
cuda_randoms_.RawData(), \
smaller_leaf_splits, \
larger_leaf_splits, \
min_data_in_leaf_, \
min_sum_hessian_in_leaf_, \
min_gain_to_split_, \
lambda_l1_, \
lambda_l2_, \
path_smooth_, \
cat_smooth_, \
cat_l2_, \
max_cat_threshold_, \
min_data_per_group_, \
cuda_best_split_info_
#define GlobalMemory_Buffer_ARGS \
cuda_feature_hist_grad_buffer_, \
cuda_feature_hist_hess_buffer_, \
cuda_feature_hist_stat_buffer_, \
cuda_feature_hist_index_buffer_
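// the LaunchFindBestSplitsForLeafKernelInner* chain below turns the runtime options
// (extra_trees_, lambda_l1_, use_smoothing_) into kernel template parameters and, depending on
// use_global_memory_, launches the shared-memory or global-memory kernels for the smaller and
// larger leaves on separate CUDA streams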
void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernel(LaunchFindBestSplitsForLeafKernel_PARAMS) {
if (!is_smaller_leaf_valid && !is_larger_leaf_valid) {
return;
}
if (!extra_trees_) {
LaunchFindBestSplitsForLeafKernelInner0<false>(LaunchFindBestSplitsForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsForLeafKernelInner0<true>(LaunchFindBestSplitsForLeafKernel_ARGS);
}
}
template <bool USE_RAND>
void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner0(LaunchFindBestSplitsForLeafKernel_PARAMS) {
if (lambda_l1_ <= 0.0f) {
LaunchFindBestSplitsForLeafKernelInner1<USE_RAND, false>(LaunchFindBestSplitsForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsForLeafKernelInner1<USE_RAND, true>(LaunchFindBestSplitsForLeafKernel_ARGS);
}
}
template <bool USE_RAND, bool USE_L1>
void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner1(LaunchFindBestSplitsForLeafKernel_PARAMS) {
if (!use_smoothing_) {
LaunchFindBestSplitsForLeafKernelInner2<USE_RAND, USE_L1, false>(LaunchFindBestSplitsForLeafKernel_ARGS);
} else {
LaunchFindBestSplitsForLeafKernelInner2<USE_RAND, USE_L1, true>(LaunchFindBestSplitsForLeafKernel_ARGS);
}
}
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void CUDABestSplitFinder::LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBestSplitsForLeafKernel_PARAMS) {
if (!use_global_memory_) {
if (is_smaller_leaf_valid) {
FindBestSplitsForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, false>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[0]>>>
(FindBestSplitsForLeafKernel_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (is_larger_leaf_valid) {
FindBestSplitsForLeafKernel<USE_RAND, USE_L1, USE_SMOOTHING, true>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[1]>>>
(FindBestSplitsForLeafKernel_ARGS);
}
} else {
if (is_smaller_leaf_valid) {
FindBestSplitsForLeafKernel_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING, false>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[0]>>>
(FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
if (is_larger_leaf_valid) {
FindBestSplitsForLeafKernel_GlobalMemory<USE_RAND, USE_L1, USE_SMOOTHING, true>
<<<num_tasks_, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER, 0, cuda_streams_[1]>>>
(FindBestSplitsForLeafKernel_ARGS, GlobalMemory_Buffer_ARGS);
}
}
}
#undef LaunchFindBestSplitsForLeafKernel_PARAMS
#undef FindBestSplitsForLeafKernel_ARGS
#undef GlobalMemory_Buffer_ARGS
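// block-wide tree reduction over (found, gain, read index) triples: after the loop,
// index 0 holds the valid candidate with the largest gain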
__device__ void ReduceBestSplit(bool* found, double* gain, uint32_t* shared_read_index,
uint32_t num_features_aligned) {
const uint32_t threadIdx_x = threadIdx.x;
for (unsigned int s = 1; s < num_features_aligned; s <<= 1) {
if (threadIdx_x % (2 * s) == 0 && (threadIdx_x + s) < num_features_aligned) {
const uint32_t pos_to_compare = threadIdx_x + s;
if ((!found[threadIdx_x] && found[pos_to_compare]) ||
(found[threadIdx_x] && found[pos_to_compare] && gain[threadIdx_x] < gain[pos_to_compare])) {
found[threadIdx_x] = found[pos_to_compare];
gain[threadIdx_x] = gain[pos_to_compare];
shared_read_index[threadIdx_x] = shared_read_index[pos_to_compare];
}
}
__syncthreads();
}
}
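// each block reduces the best split over up to NUM_TASKS_PER_SYNC_BLOCK tasks of one leaf
// and writes the winner to the (leaf, block) slot of cuda_leaf_best_split_info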
__global__ void SyncBestSplitForLeafKernel(const int smaller_leaf_index, const int larger_leaf_index,
CUDASplitInfo* cuda_leaf_best_split_info,
// input parameters
const SplitFindTask* tasks,
const CUDASplitInfo* cuda_best_split_info,
const int num_tasks,
const int num_tasks_aligned,
const int num_blocks_per_leaf,
const bool larger_only,
const int num_leaves) {
__shared__ double shared_gain_buffer[32];
__shared__ bool shared_found_buffer[32];
__shared__ uint32_t shared_thread_index_buffer[32];
const uint32_t threadIdx_x = threadIdx.x;
const uint32_t blockIdx_x = blockIdx.x;
bool best_found = false;
double best_gain = kMinScore;
uint32_t shared_read_index = 0;
const bool is_smaller = (blockIdx_x < static_cast<unsigned int>(num_blocks_per_leaf) && !larger_only);
const uint32_t leaf_block_index = (is_smaller || larger_only) ? blockIdx_x : (blockIdx_x - static_cast<unsigned int>(num_blocks_per_leaf));
const int task_index = static_cast<int>(leaf_block_index * blockDim.x + threadIdx_x);
const uint32_t read_index = is_smaller ? static_cast<uint32_t>(task_index) : static_cast<uint32_t>(task_index + num_tasks);
if (task_index < num_tasks) {
best_found = cuda_best_split_info[read_index].is_valid;
best_gain = cuda_best_split_info[read_index].gain;
shared_read_index = read_index;
} else {
best_found = false;
}
__syncthreads();
const uint32_t best_read_index = ReduceBestGain(best_gain, best_found, shared_read_index,
shared_gain_buffer, shared_found_buffer, shared_thread_index_buffer);
if (threadIdx.x == 0) {
const int leaf_index_ref = is_smaller ? smaller_leaf_index : larger_leaf_index;
const unsigned buffer_write_pos = static_cast<unsigned int>(leaf_index_ref) + leaf_block_index * num_leaves;
CUDASplitInfo* cuda_split_info = cuda_leaf_best_split_info + buffer_write_pos;
const CUDASplitInfo* best_split_info = cuda_best_split_info + best_read_index;
if (best_split_info->is_valid) {
*cuda_split_info = *best_split_info;
cuda_split_info->inner_feature_index = is_smaller ? tasks[best_read_index].inner_feature_index :
tasks[static_cast<int>(best_read_index) - num_tasks].inner_feature_index;
cuda_split_info->is_valid = true;
} else {
cuda_split_info->gain = kMinScore;
cuda_split_info->is_valid = false;
}
}
}
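// merges the per-block winners written by SyncBestSplitForLeafKernel into the final
// best-split entry of each leaf when the tasks span multiple blocks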
__global__ void SyncBestSplitForLeafKernelAllBlocks(
const int smaller_leaf_index,
const int larger_leaf_index,
const unsigned int num_blocks_per_leaf,
const int num_leaves,
CUDASplitInfo* cuda_leaf_best_split_info,
const bool larger_only) {
if (!larger_only) {
if (blockIdx.x == 0) {
CUDASplitInfo* smaller_leaf_split_info = cuda_leaf_best_split_info + smaller_leaf_index;
for (unsigned int block_index = 1; block_index < num_blocks_per_leaf; ++block_index) {
const unsigned int leaf_read_pos = static_cast<unsigned int>(smaller_leaf_index) + block_index * static_cast<unsigned int>(num_leaves);
const CUDASplitInfo* other_split_info = cuda_leaf_best_split_info + leaf_read_pos;
if ((other_split_info->is_valid && smaller_leaf_split_info->is_valid &&
other_split_info->gain > smaller_leaf_split_info->gain) ||
(!smaller_leaf_split_info->is_valid && other_split_info->is_valid)) {
*smaller_leaf_split_info = *other_split_info;
}
}
}
}
if (larger_leaf_index >= 0) {
if (blockIdx.x == 1 || larger_only) {
CUDASplitInfo* larger_leaf_split_info = cuda_leaf_best_split_info + larger_leaf_index;
for (unsigned int block_index = 1; block_index < num_blocks_per_leaf; ++block_index) {
const unsigned int leaf_read_pos = static_cast<unsigned int>(larger_leaf_index) + block_index * static_cast<unsigned int>(num_leaves);
const CUDASplitInfo* other_split_info = cuda_leaf_best_split_info + leaf_read_pos;
if ((other_split_info->is_valid && larger_leaf_split_info->is_valid &&
other_split_info->gain > larger_leaf_split_info->gain) ||
(!larger_leaf_split_info->is_valid && other_split_info->is_valid)) {
*larger_leaf_split_info = *other_split_info;
}
}
}
}
}
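// invalidates the stored best split of leaves that were skipped in this round of split finding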
__global__ void SetInvalidLeafSplitInfoKernel(
CUDASplitInfo* cuda_leaf_best_split_info,
const bool is_smaller_leaf_valid,
const bool is_larger_leaf_valid,
const int smaller_leaf_index,
const int larger_leaf_index) {
if (!is_smaller_leaf_valid) {
cuda_leaf_best_split_info[smaller_leaf_index].is_valid = false;
}
if (!is_larger_leaf_valid && larger_leaf_index >= 0) {
cuda_leaf_best_split_info[larger_leaf_index].is_valid = false;
}
}
void CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel(
const int host_smaller_leaf_index,
const int host_larger_leaf_index,
const bool is_smaller_leaf_valid,
const bool is_larger_leaf_valid) {
if (!is_smaller_leaf_valid || !is_larger_leaf_valid) {
SetInvalidLeafSplitInfoKernel<<<1, 1>>>(
cuda_leaf_best_split_info_,
is_smaller_leaf_valid, is_larger_leaf_valid,
host_smaller_leaf_index, host_larger_leaf_index);
}
if (!is_smaller_leaf_valid && !is_larger_leaf_valid) {
return;
}
int num_tasks = num_tasks_;
int num_tasks_aligned = 1;
num_tasks -= 1;
while (num_tasks > 0) {
num_tasks_aligned <<= 1;
num_tasks >>= 1;
}
const int num_blocks_per_leaf = (num_tasks_ + NUM_TASKS_PER_SYNC_BLOCK - 1) / NUM_TASKS_PER_SYNC_BLOCK;
if (host_larger_leaf_index >= 0 && is_smaller_leaf_valid && is_larger_leaf_valid) {
SyncBestSplitForLeafKernel<<<num_blocks_per_leaf, NUM_TASKS_PER_SYNC_BLOCK, 0, cuda_streams_[0]>>>(
host_smaller_leaf_index,
host_larger_leaf_index,
cuda_leaf_best_split_info_,
cuda_split_find_tasks_.RawData(),
cuda_best_split_info_,
num_tasks_,
num_tasks_aligned,
num_blocks_per_leaf,
false,
num_leaves_);
if (num_blocks_per_leaf > 1) {
SyncBestSplitForLeafKernelAllBlocks<<<1, 1, 0, cuda_streams_[0]>>>(
host_smaller_leaf_index,
host_larger_leaf_index,
num_blocks_per_leaf,
num_leaves_,
cuda_leaf_best_split_info_,
false);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
SyncBestSplitForLeafKernel<<<num_blocks_per_leaf, NUM_TASKS_PER_SYNC_BLOCK, 0, cuda_streams_[1]>>>(
host_smaller_leaf_index,
host_larger_leaf_index,
cuda_leaf_best_split_info_,
cuda_split_find_tasks_.RawData(),
cuda_best_split_info_,
num_tasks_,
num_tasks_aligned,
num_blocks_per_leaf,
true,
num_leaves_);
if (num_blocks_per_leaf > 1) {
SyncBestSplitForLeafKernelAllBlocks<<<1, 1, 0, cuda_streams_[1]>>>(
host_smaller_leaf_index,
host_larger_leaf_index,
num_blocks_per_leaf,
num_leaves_,
cuda_leaf_best_split_info_,
true);
}
} else {
const bool larger_only = (!is_smaller_leaf_valid && is_larger_leaf_valid);
SyncBestSplitForLeafKernel<<<num_blocks_per_leaf, NUM_TASKS_PER_SYNC_BLOCK>>>(
host_smaller_leaf_index,
host_larger_leaf_index,
cuda_leaf_best_split_info_,
cuda_split_find_tasks_.RawData(),
cuda_best_split_info_,
num_tasks_,
num_tasks_aligned,
num_blocks_per_leaf,
larger_only,
num_leaves_);
if (num_blocks_per_leaf > 1) {
SynchronizeCUDADevice(__FILE__, __LINE__);
SyncBestSplitForLeafKernelAllBlocks<<<1, 1>>>(
host_smaller_leaf_index,
host_larger_leaf_index,
num_blocks_per_leaf,
num_leaves_,
cuda_leaf_best_split_info_,
larger_only);
}
}
}
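// single-block reduction over all current leaves: finds the leaf with the largest valid
// split gain and records its index and number of categorical thresholds in the host-readable buffer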
__global__ void FindBestFromAllSplitsKernel(const int cur_num_leaves,
CUDASplitInfo* cuda_leaf_best_split_info,
int* cuda_best_split_info_buffer) {
__shared__ double gain_shared_buffer[32];
__shared__ int leaf_index_shared_buffer[32];
double thread_best_gain = kMinScore;
int thread_best_leaf_index = -1;
const int threadIdx_x = static_cast<int>(threadIdx.x);
for (int leaf_index = threadIdx_x; leaf_index < cur_num_leaves; leaf_index += static_cast<int>(blockDim.x)) {
const double leaf_best_gain = cuda_leaf_best_split_info[leaf_index].gain;
if (cuda_leaf_best_split_info[leaf_index].is_valid && leaf_best_gain > thread_best_gain) {
thread_best_gain = leaf_best_gain;
thread_best_leaf_index = leaf_index;
}
}
const int best_leaf_index = ReduceBestGainForLeaves(thread_best_gain, thread_best_leaf_index, gain_shared_buffer, leaf_index_shared_buffer);
if (threadIdx_x == 0) {
cuda_best_split_info_buffer[6] = best_leaf_index;
if (best_leaf_index != -1) {
cuda_leaf_best_split_info[best_leaf_index].is_valid = false;
cuda_leaf_best_split_info[cur_num_leaves].is_valid = false;
cuda_best_split_info_buffer[7] = cuda_leaf_best_split_info[best_leaf_index].num_cat_threshold;
}
}
}
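// launched with one block per field: copies feature index, threshold and default_left of
// the smaller (and, if present, larger) leaf's best split into the host-readable int buffer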
__global__ void PrepareLeafBestSplitInfo(const int smaller_leaf_index, const int larger_leaf_index,
int* cuda_best_split_info_buffer,
const CUDASplitInfo* cuda_leaf_best_split_info) {
const unsigned int block_index = blockIdx.x;
if (block_index == 0) {
cuda_best_split_info_buffer[0] = cuda_leaf_best_split_info[smaller_leaf_index].inner_feature_index;
} else if (block_index == 1) {
cuda_best_split_info_buffer[1] = cuda_leaf_best_split_info[smaller_leaf_index].threshold;
} else if (block_index == 2) {
cuda_best_split_info_buffer[2] = cuda_leaf_best_split_info[smaller_leaf_index].default_left;
}
if (larger_leaf_index >= 0) {
if (block_index == 3) {
cuda_best_split_info_buffer[3] = cuda_leaf_best_split_info[larger_leaf_index].inner_feature_index;
} else if (block_index == 4) {
cuda_best_split_info_buffer[4] = cuda_leaf_best_split_info[larger_leaf_index].threshold;
} else if (block_index == 5) {
cuda_best_split_info_buffer[5] = cuda_leaf_best_split_info[larger_leaf_index].default_left;
}
}
}
void CUDABestSplitFinder::LaunchFindBestFromAllSplitsKernel(
const int cur_num_leaves,
const int smaller_leaf_index, const int larger_leaf_index,
int* smaller_leaf_best_split_feature,
uint32_t* smaller_leaf_best_split_threshold,
uint8_t* smaller_leaf_best_split_default_left,
int* larger_leaf_best_split_feature,
uint32_t* larger_leaf_best_split_threshold,
uint8_t* larger_leaf_best_split_default_left,
int* best_leaf_index,
int* num_cat_threshold) {
FindBestFromAllSplitsKernel<<<1, NUM_THREADS_FIND_BEST_LEAF, 0, cuda_streams_[1]>>>(cur_num_leaves,
cuda_leaf_best_split_info_,
cuda_best_split_info_buffer_);
PrepareLeafBestSplitInfo<<<6, 1, 0, cuda_streams_[0]>>>(smaller_leaf_index, larger_leaf_index,
cuda_best_split_info_buffer_,
cuda_leaf_best_split_info_);
std::vector<int> host_leaf_best_split_info_buffer(8, 0);
SynchronizeCUDADevice(__FILE__, __LINE__);
CopyFromCUDADeviceToHost<int>(host_leaf_best_split_info_buffer.data(), cuda_best_split_info_buffer_, 8, __FILE__, __LINE__);
*smaller_leaf_best_split_feature = host_leaf_best_split_info_buffer[0];
*smaller_leaf_best_split_threshold = static_cast<uint32_t>(host_leaf_best_split_info_buffer[1]);
*smaller_leaf_best_split_default_left = static_cast<uint8_t>(host_leaf_best_split_info_buffer[2]);
if (larger_leaf_index >= 0) {
*larger_leaf_best_split_feature = host_leaf_best_split_info_buffer[3];
*larger_leaf_best_split_threshold = static_cast<uint32_t>(host_leaf_best_split_info_buffer[4]);
*larger_leaf_best_split_default_left = static_cast<uint8_t>(host_leaf_best_split_info_buffer[5]);
}
*best_leaf_index = host_leaf_best_split_info_buffer[6];
*num_cat_threshold = host_leaf_best_split_info_buffer[7];
}
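// points the categorical threshold arrays of each CUDASplitInfo into the pre-allocated
// flat device buffers (nullptr when the dataset has no categorical feature)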
__global__ void AllocateCatVectorsKernel(
CUDASplitInfo* cuda_split_infos, size_t len,
const int max_num_categories_in_split,
const bool has_categorical_feature,
uint32_t* cat_threshold_vec,
int* cat_threshold_real_vec) {
const size_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < len) {
if (has_categorical_feature) {
cuda_split_infos[i].cat_threshold = cat_threshold_vec + i * max_num_categories_in_split;
cuda_split_infos[i].cat_threshold_real = cat_threshold_real_vec + i * max_num_categories_in_split;
cuda_split_infos[i].num_cat_threshold = 0;
} else {
cuda_split_infos[i].cat_threshold = nullptr;
cuda_split_infos[i].cat_threshold_real = nullptr;
cuda_split_infos[i].num_cat_threshold = 0;
}
}
}
void CUDABestSplitFinder::LaunchAllocateCatVectorsKernel(
CUDASplitInfo* cuda_split_infos, uint32_t* cat_threshold_vec, int* cat_threshold_real_vec, size_t len) {
const int num_blocks = (static_cast<int>(len) + NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER - 1) / NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER;
AllocateCatVectorsKernel<<<num_blocks, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER>>>(
cuda_split_infos, len, max_num_categories_in_split_, has_categorical_feature_, cat_threshold_vec, cat_threshold_real_vec);
}
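// seeds one CUDARandom state per split-finding task, used when extra_trees is enabled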
__global__ void InitCUDARandomKernel(
const int seed,
const int num_tasks,
CUDARandom* cuda_randoms) {
const int task_index = static_cast<int>(threadIdx.x + blockIdx.x * blockDim.x);
if (task_index < num_tasks) {
cuda_randoms[task_index].SetSeed(seed + task_index);
}
}
void CUDABestSplitFinder::LaunchInitCUDARandomKernel() {
const int num_blocks = (static_cast<int>(cuda_randoms_.Size()) +
NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER - 1) / NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER;
InitCUDARandomKernel<<<num_blocks, NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER>>>(extra_seed_,
static_cast<int>(cuda_randoms_.Size()), cuda_randoms_.RawData());
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_CUDA_BEST_SPLIT_FINDER_HPP_
#define LIGHTGBM_TREELEARNER_CUDA_CUDA_BEST_SPLIT_FINDER_HPP_
#ifdef USE_CUDA_EXP
#include <LightGBM/bin.h>
#include <LightGBM/dataset.h>
#include <LightGBM/cuda/cuda_random.hpp>
#include <LightGBM/cuda/cuda_split_info.hpp>
#include <vector>
#include "cuda_leaf_splits.hpp"
#define NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER (256)
#define NUM_THREADS_FIND_BEST_LEAF (256)
#define NUM_TASKS_PER_SYNC_BLOCK (1024)
namespace LightGBM {
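// description of one best-split-finding task; tasks, not features, are the unit of
// work in split finding, and a feature may map to more than one task (the `reverse`
// flag distinguishes the scan direction over its histogram)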
struct SplitFindTask {
int inner_feature_index;
bool reverse;
bool skip_default_bin;
bool na_as_missing;
bool assume_out_default_left;
bool is_categorical;
bool is_one_hot;
uint32_t hist_offset;
uint8_t mfb_offset;
uint32_t num_bin;
uint32_t default_bin;
int rand_threshold;
};
class CUDABestSplitFinder {
public:
CUDABestSplitFinder(
const hist_t* cuda_hist,
const Dataset* train_data,
const std::vector<uint32_t>& feature_hist_offsets,
const Config* config);
~CUDABestSplitFinder();
void InitFeatureMetaInfo(const Dataset* train_data);
void Init();
void InitCUDAFeatureMetaInfo();
void BeforeTrain(const std::vector<int8_t>& is_feature_used_bytree);
void FindBestSplitsForLeaf(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const CUDALeafSplitsStruct* larger_leaf_splits,
const int smaller_leaf_index,
const int larger_leaf_index,
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf);
const CUDASplitInfo* FindBestFromAllSplits(
const int cur_num_leaves,
const int smaller_leaf_index,
const int larger_leaf_index,
int* smaller_leaf_best_split_feature,
uint32_t* smaller_leaf_best_split_threshold,
uint8_t* smaller_leaf_best_split_default_left,
int* larger_leaf_best_split_feature,
uint32_t* larger_leaf_best_split_threshold,
uint8_t* larger_leaf_best_split_default_left,
int* best_leaf_index,
int* num_cat_threshold);
void ResetTrainingData(
const hist_t* cuda_hist,
const Dataset* train_data,
const std::vector<uint32_t>& feature_hist_offsets);
void ResetConfig(const Config* config, const hist_t* cuda_hist);
private:
#define LaunchFindBestSplitsForLeafKernel_PARAMS \
const CUDALeafSplitsStruct* smaller_leaf_splits, \
const CUDALeafSplitsStruct* larger_leaf_splits, \
const int smaller_leaf_index, \
const int larger_leaf_index, \
const bool is_smaller_leaf_valid, \
const bool is_larger_leaf_valid
void LaunchFindBestSplitsForLeafKernel(LaunchFindBestSplitsForLeafKernel_PARAMS);
template <bool USE_RAND>
void LaunchFindBestSplitsForLeafKernelInner0(LaunchFindBestSplitsForLeafKernel_PARAMS);
template <bool USE_RAND, bool USE_L1>
void LaunchFindBestSplitsForLeafKernelInner1(LaunchFindBestSplitsForLeafKernel_PARAMS);
template <bool USE_RAND, bool USE_L1, bool USE_SMOOTHING>
void LaunchFindBestSplitsForLeafKernelInner2(LaunchFindBestSplitsForLeafKernel_PARAMS);
#undef LaunchFindBestSplitsForLeafKernel_PARAMS
void LaunchSyncBestSplitForLeafKernel(
const int host_smaller_leaf_index,
const int host_larger_leaf_index,
const bool is_smaller_leaf_valid,
const bool is_larger_leaf_valid);
void LaunchFindBestFromAllSplitsKernel(
const int cur_num_leaves,
const int smaller_leaf_index,
const int larger_leaf_index,
int* smaller_leaf_best_split_feature,
uint32_t* smaller_leaf_best_split_threshold,
uint8_t* smaller_leaf_best_split_default_left,
int* larger_leaf_best_split_feature,
uint32_t* larger_leaf_best_split_threshold,
uint8_t* larger_leaf_best_split_default_left,
int* best_leaf_index,
int* num_cat_threshold);
void AllocateCatVectors(CUDASplitInfo* cuda_split_infos, uint32_t* cat_threshold_vec, int* cat_threshold_real_vec, size_t len);
void LaunchAllocateCatVectorsKernel(CUDASplitInfo* cuda_split_infos, uint32_t* cat_threshold_vec, int* cat_threshold_real_vec, size_t len);
void LaunchInitCUDARandomKernel();
// Host memory
int num_features_;
int num_leaves_;
int max_num_bin_in_feature_;
std::vector<uint32_t> feature_hist_offsets_;
std::vector<uint8_t> feature_mfb_offsets_;
std::vector<uint32_t> feature_default_bins_;
std::vector<uint32_t> feature_num_bins_;
std::vector<MissingType> feature_missing_type_;
double lambda_l1_;
double lambda_l2_;
data_size_t min_data_in_leaf_;
double min_sum_hessian_in_leaf_;
double min_gain_to_split_;
double cat_smooth_;
double cat_l2_;
int max_cat_threshold_;
int min_data_per_group_;
int max_cat_to_onehot_;
bool extra_trees_;
int extra_seed_;
bool use_smoothing_;
double path_smooth_;
std::vector<cudaStream_t> cuda_streams_;
// tasks for best split finding
std::vector<SplitFindTask> split_find_tasks_;
int num_tasks_;
// use global memory
bool use_global_memory_;
// number of total bins in the dataset
const int num_total_bin_;
// has categorical feature
bool has_categorical_feature_;
// maximum number of bins of categorical features
int max_num_categorical_bin_;
// marks whether a feature is categorical
std::vector<int8_t> is_categorical_;
// CUDA memory, held by this object
// for per leaf best split information
CUDASplitInfo* cuda_leaf_best_split_info_;
// for best split information when finding best split
CUDASplitInfo* cuda_best_split_info_;
// best split information buffer, to be copied to host
int* cuda_best_split_info_buffer_;
// find best split task information
CUDAVector<SplitFindTask> cuda_split_find_tasks_;
int8_t* cuda_is_feature_used_bytree_;
// used when finding best split with global memory
hist_t* cuda_feature_hist_grad_buffer_;
hist_t* cuda_feature_hist_hess_buffer_;
hist_t* cuda_feature_hist_stat_buffer_;
data_size_t* cuda_feature_hist_index_buffer_;
uint32_t* cuda_cat_threshold_leaf_;
int* cuda_cat_threshold_real_leaf_;
uint32_t* cuda_cat_threshold_feature_;
int* cuda_cat_threshold_real_feature_;
int max_num_categories_in_split_;
// used for extremely randomized trees
CUDAVector<CUDARandom> cuda_randoms_;
// CUDA memory, held by another object
const hist_t* cuda_hist_;
};
} // namespace LightGBM
#endif // USE_CUDA_EXP
#endif // LIGHTGBM_TREELEARNER_CUDA_CUDA_BEST_SPLIT_FINDER_HPP_
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include <algorithm>
#include <memory>
#include "cuda_data_partition.hpp"
namespace LightGBM {
CUDADataPartition::CUDADataPartition(
const Dataset* train_data,
const int num_total_bin,
const int num_leaves,
const int num_threads,
hist_t* cuda_hist):
num_data_(train_data->num_data()),
num_features_(train_data->num_features()),
num_total_bin_(num_total_bin),
num_leaves_(num_leaves),
num_threads_(num_threads),
cuda_hist_(cuda_hist) {
CalcBlockDim(num_data_);
max_num_split_indices_blocks_ = grid_dim_;
cur_num_leaves_ = 1;
cuda_column_data_ = train_data->cuda_column_data();
is_categorical_feature_.resize(train_data->num_features(), false);
is_single_feature_in_column_.resize(train_data->num_features(), false);
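// a feature is alone in its column either when its feature group is a multi-value
// group, or when neither of its neighboring features belongs to the same (dense) group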
for (int feature_index = 0; feature_index < train_data->num_features(); ++feature_index) {
if (train_data->FeatureBinMapper(feature_index)->bin_type() == BinType::CategoricalBin) {
is_categorical_feature_[feature_index] = true;
}
const int feature_group_index = train_data->Feature2Group(feature_index);
if (!train_data->IsMultiGroup(feature_group_index)) {
if ((feature_index == 0 || train_data->Feature2Group(feature_index - 1) != feature_group_index) &&
(feature_index == train_data->num_features() - 1 || train_data->Feature2Group(feature_index + 1) != feature_group_index)) {
is_single_feature_in_column_[feature_index] = true;
}
} else {
is_single_feature_in_column_[feature_index] = true;
}
}
cuda_data_indices_ = nullptr;
cuda_leaf_data_start_ = nullptr;
cuda_leaf_data_end_ = nullptr;
cuda_leaf_num_data_ = nullptr;
cuda_hist_pool_ = nullptr;
cuda_leaf_output_ = nullptr;
cuda_block_to_left_offset_ = nullptr;
cuda_data_index_to_leaf_index_ = nullptr;
cuda_block_data_to_left_offset_ = nullptr;
cuda_block_data_to_right_offset_ = nullptr;
cuda_out_data_indices_in_leaf_ = nullptr;
cuda_split_info_buffer_ = nullptr;
cuda_num_data_ = nullptr;
cuda_add_train_score_ = nullptr;
}
CUDADataPartition::~CUDADataPartition() {
DeallocateCUDAMemory<data_size_t>(&cuda_data_indices_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_leaf_data_start_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_leaf_data_end_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_leaf_num_data_, __FILE__, __LINE__);
DeallocateCUDAMemory<hist_t*>(&cuda_hist_pool_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_leaf_output_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint16_t>(&cuda_block_to_left_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_data_index_to_leaf_index_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_block_data_to_left_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_block_data_to_right_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_out_data_indices_in_leaf_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_split_info_buffer_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_num_data_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_add_train_score_, __FILE__, __LINE__);
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_streams_[0]));
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_streams_[1]));
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_streams_[2]));
CUDASUCCESS_OR_FATAL(cudaStreamDestroy(cuda_streams_[3]));
cuda_streams_.clear();
cuda_streams_.shrink_to_fit();
}
void CUDADataPartition::Init() {
// allocate CUDA memory
AllocateCUDAMemory<data_size_t>(&cuda_data_indices_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_leaf_data_start_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_leaf_data_end_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_leaf_num_data_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
// leave some space for alignment
AllocateCUDAMemory<uint16_t>(&cuda_block_to_left_offset_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_data_index_to_leaf_index_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_block_data_to_left_offset_, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_block_data_to_right_offset_, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
SetCUDAMemory<data_size_t>(cuda_block_data_to_left_offset_, 0, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
SetCUDAMemory<data_size_t>(cuda_block_data_to_right_offset_, 0, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_out_data_indices_in_leaf_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<hist_t*>(&cuda_hist_pool_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
CopyFromHostToCUDADevice<hist_t*>(cuda_hist_pool_, &cuda_hist_, 1, __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_split_info_buffer_, 16, __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_leaf_output_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
cuda_streams_.resize(4);
gpuAssert(cudaStreamCreate(&cuda_streams_[0]), __FILE__, __LINE__);
gpuAssert(cudaStreamCreate(&cuda_streams_[1]), __FILE__, __LINE__);
gpuAssert(cudaStreamCreate(&cuda_streams_[2]), __FILE__, __LINE__);
gpuAssert(cudaStreamCreate(&cuda_streams_[3]), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<data_size_t>(&cuda_num_data_, &num_data_, 1, __FILE__, __LINE__);
add_train_score_.resize(num_data_, 0.0f);
AllocateCUDAMemory<double>(&cuda_add_train_score_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
use_bagging_ = false;
used_indices_ = nullptr;
}
void CUDADataPartition::BeforeTrain() {
if (!use_bagging_) {
LaunchFillDataIndicesBeforeTrain();
}
SetCUDAMemory<data_size_t>(cuda_leaf_num_data_, 0, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
SetCUDAMemory<data_size_t>(cuda_leaf_data_start_, 0, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
SetCUDAMemory<data_size_t>(cuda_leaf_data_end_, 0, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
if (!use_bagging_) {
CopyFromCUDADeviceToCUDADevice<data_size_t>(cuda_leaf_num_data_, cuda_num_data_, 1, __FILE__, __LINE__);
CopyFromCUDADeviceToCUDADevice<data_size_t>(cuda_leaf_data_end_, cuda_num_data_, 1, __FILE__, __LINE__);
} else {
CopyFromHostToCUDADevice<data_size_t>(cuda_leaf_num_data_, &num_used_indices_, 1, __FILE__, __LINE__);
CopyFromHostToCUDADevice<data_size_t>(cuda_leaf_data_end_, &num_used_indices_, 1, __FILE__, __LINE__);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
CopyFromHostToCUDADevice<hist_t*>(cuda_hist_pool_, &cuda_hist_, 1, __FILE__, __LINE__);
}
void CUDADataPartition::Split(
// input best split info
const CUDASplitInfo* best_split_info,
const int left_leaf_index,
const int right_leaf_index,
const int leaf_best_split_feature,
const uint32_t leaf_best_split_threshold,
const uint32_t* categorical_bitset,
const int categorical_bitset_len,
const uint8_t leaf_best_split_default_left,
const data_size_t num_data_in_leaf,
const data_size_t leaf_data_start,
// for leaf information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
// gather information for CPU, used for launching kernels
data_size_t* left_leaf_num_data,
data_size_t* right_leaf_num_data,
data_size_t* left_leaf_start,
data_size_t* right_leaf_start,
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients) {
CalcBlockDim(num_data_in_leaf);
global_timer.Start("GenDataToLeftBitVector");
GenDataToLeftBitVector(num_data_in_leaf,
leaf_best_split_feature,
leaf_best_split_threshold,
categorical_bitset,
categorical_bitset_len,
leaf_best_split_default_left,
leaf_data_start,
left_leaf_index,
right_leaf_index);
global_timer.Stop("GenDataToLeftBitVector");
global_timer.Start("SplitInner");
SplitInner(num_data_in_leaf,
best_split_info,
left_leaf_index,
right_leaf_index,
smaller_leaf_splits,
larger_leaf_splits,
left_leaf_num_data,
right_leaf_num_data,
left_leaf_start,
right_leaf_start,
left_leaf_sum_of_hessians,
right_leaf_sum_of_hessians,
left_leaf_sum_of_gradients,
right_leaf_sum_of_gradients);
global_timer.Stop("SplitInner");
}
void CUDADataPartition::GenDataToLeftBitVector(
const data_size_t num_data_in_leaf,
const int split_feature_index,
const uint32_t split_threshold,
const uint32_t* categorical_bitset,
const int categorical_bitset_len,
const uint8_t split_default_left,
const data_size_t leaf_data_start,
const int left_leaf_index,
const int right_leaf_index) {
if (is_categorical_feature_[split_feature_index]) {
LaunchGenDataToLeftBitVectorCategoricalKernel(
num_data_in_leaf,
split_feature_index,
categorical_bitset,
categorical_bitset_len,
split_default_left,
leaf_data_start,
left_leaf_index,
right_leaf_index);
} else {
LaunchGenDataToLeftBitVectorKernel(
num_data_in_leaf,
split_feature_index,
split_threshold,
split_default_left,
leaf_data_start,
left_leaf_index,
right_leaf_index);
}
}
void CUDADataPartition::SplitInner(
const data_size_t num_data_in_leaf,
const CUDASplitInfo* best_split_info,
const int left_leaf_index,
const int right_leaf_index,
// for leaf splits information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
data_size_t* left_leaf_num_data,
data_size_t* right_leaf_num_data,
data_size_t* left_leaf_start,
data_size_t* right_leaf_start,
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients) {
LaunchSplitInnerKernel(
num_data_in_leaf,
best_split_info,
left_leaf_index,
right_leaf_index,
smaller_leaf_splits,
larger_leaf_splits,
left_leaf_num_data,
right_leaf_num_data,
left_leaf_start,
right_leaf_start,
left_leaf_sum_of_hessians,
right_leaf_sum_of_hessians,
left_leaf_sum_of_gradients,
right_leaf_sum_of_gradients);
++cur_num_leaves_;
}
void CUDADataPartition::UpdateTrainScore(const Tree* tree, double* scores) {
const CUDATree* cuda_tree = nullptr;
std::unique_ptr<CUDATree> cuda_tree_ptr;
if (tree->is_cuda_tree()) {
cuda_tree = reinterpret_cast<const CUDATree*>(tree);
} else {
cuda_tree_ptr.reset(new CUDATree(tree));
cuda_tree = cuda_tree_ptr.get();
}
const data_size_t num_data_in_root = root_num_data();
if (use_bagging_) {
// we need to restore the order of indices in cuda_data_indices_
CopyFromHostToCUDADevice<data_size_t>(cuda_data_indices_, used_indices_, static_cast<size_t>(num_used_indices_), __FILE__, __LINE__);
}
LaunchAddPredictionToScoreKernel(cuda_tree->cuda_leaf_value(), cuda_add_train_score_);
CopyFromCUDADeviceToHost<double>(add_train_score_.data(),
cuda_add_train_score_, static_cast<size_t>(num_data_in_root), __FILE__, __LINE__);
if (!use_bagging_) {
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads_)
for (data_size_t data_index = 0; data_index < num_data_in_root; ++data_index) {
OMP_LOOP_EX_BEGIN();
scores[data_index] += add_train_score_[data_index];
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
} else {
OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(num_threads_)
for (data_size_t data_index = 0; data_index < num_data_in_root; ++data_index) {
OMP_LOOP_EX_BEGIN();
scores[used_indices_[data_index]] += add_train_score_[data_index];
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
}
}
void CUDADataPartition::CalcBlockDim(const data_size_t num_data_in_leaf) {
const int min_num_blocks = num_data_in_leaf <= 100 ? 1 : 80;
const int num_blocks = std::max(min_num_blocks, (num_data_in_leaf + SPLIT_INDICES_BLOCK_SIZE_DATA_PARTITION - 1) / SPLIT_INDICES_BLOCK_SIZE_DATA_PARTITION);
int split_indices_block_size_data_partition = (num_data_in_leaf + num_blocks - 1) / num_blocks - 1;
CHECK_GT(split_indices_block_size_data_partition, 0);
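// round the per-block data count up to the next power of two and derive the grid size
// from it, so that block_dim_ is always a power of two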
int split_indices_block_size_data_partition_aligned = 1;
while (split_indices_block_size_data_partition > 0) {
split_indices_block_size_data_partition_aligned <<= 1;
split_indices_block_size_data_partition >>= 1;
}
const int num_blocks_final = (num_data_in_leaf + split_indices_block_size_data_partition_aligned - 1) / split_indices_block_size_data_partition_aligned;
grid_dim_ = num_blocks_final;
block_dim_ = split_indices_block_size_data_partition_aligned;
}
void CUDADataPartition::SetUsedDataIndices(const data_size_t* used_indices, const data_size_t num_used_indices) {
use_bagging_ = true;
num_used_indices_ = num_used_indices;
used_indices_ = used_indices;
CopyFromHostToCUDADevice<data_size_t>(cuda_data_indices_, used_indices, static_cast<size_t>(num_used_indices), __FILE__, __LINE__);
LaunchFillDataIndexToLeafIndex();
}
void CUDADataPartition::ResetTrainingData(const Dataset* train_data, const int num_total_bin, hist_t* cuda_hist) {
const data_size_t old_num_data = num_data_;
num_data_ = train_data->num_data();
num_features_ = train_data->num_features();
num_total_bin_ = num_total_bin;
cuda_column_data_ = train_data->cuda_column_data();
cuda_hist_ = cuda_hist;
CopyFromHostToCUDADevice<hist_t*>(cuda_hist_pool_, &cuda_hist_, 1, __FILE__, __LINE__);
CopyFromHostToCUDADevice<data_size_t>(cuda_num_data_, &num_data_, 1, __FILE__, __LINE__);
if (num_data_ > old_num_data) {
CalcBlockDim(num_data_);
const int old_max_num_split_indices_blocks = max_num_split_indices_blocks_;
max_num_split_indices_blocks_ = grid_dim_;
if (max_num_split_indices_blocks_ > old_max_num_split_indices_blocks) {
DeallocateCUDAMemory<data_size_t>(&cuda_block_data_to_left_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_block_data_to_right_offset_, __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_block_data_to_left_offset_, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_block_data_to_right_offset_, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
SetCUDAMemory<data_size_t>(cuda_block_data_to_left_offset_, 0, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
SetCUDAMemory<data_size_t>(cuda_block_data_to_right_offset_, 0, static_cast<size_t>(max_num_split_indices_blocks_) + 1, __FILE__, __LINE__);
}
DeallocateCUDAMemory<data_size_t>(&cuda_data_indices_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint16_t>(&cuda_block_to_left_offset_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_data_index_to_leaf_index_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_out_data_indices_in_leaf_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_add_train_score_, __FILE__, __LINE__);
add_train_score_.resize(num_data_, 0.0f);
AllocateCUDAMemory<data_size_t>(&cuda_data_indices_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<uint16_t>(&cuda_block_to_left_offset_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<int>(&cuda_data_index_to_leaf_index_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_out_data_indices_in_leaf_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_add_train_score_, static_cast<size_t>(num_data_), __FILE__, __LINE__);
}
used_indices_ = nullptr;
use_bagging_ = false;
num_used_indices_ = 0;
cur_num_leaves_ = 1;
}
void CUDADataPartition::ResetConfig(const Config* config, hist_t* cuda_hist) {
num_threads_ = OMP_NUM_THREADS();
num_leaves_ = config->num_leaves;
cuda_hist_ = cuda_hist;
DeallocateCUDAMemory<data_size_t>(&cuda_leaf_data_start_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_leaf_data_end_, __FILE__, __LINE__);
DeallocateCUDAMemory<data_size_t>(&cuda_leaf_num_data_, __FILE__, __LINE__);
DeallocateCUDAMemory<hist_t*>(&cuda_hist_pool_, __FILE__, __LINE__);
DeallocateCUDAMemory<double>(&cuda_leaf_output_, __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_leaf_data_start_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_leaf_data_end_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
AllocateCUDAMemory<data_size_t>(&cuda_leaf_num_data_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
AllocateCUDAMemory<hist_t*>(&cuda_hist_pool_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
AllocateCUDAMemory<double>(&cuda_leaf_output_, static_cast<size_t>(num_leaves_), __FILE__, __LINE__);
}
void CUDADataPartition::SetBaggingSubset(const Dataset* subset) {
num_used_indices_ = subset->num_data();
used_indices_ = nullptr;
use_bagging_ = true;
cuda_column_data_ = subset->cuda_column_data();
}
void CUDADataPartition::ResetByLeafPred(const std::vector<int>& leaf_pred, int num_leaves) {
if (leaf_pred.size() != static_cast<size_t>(num_data_)) {
DeallocateCUDAMemory<int>(&cuda_data_index_to_leaf_index_, __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<int>(&cuda_data_index_to_leaf_index_, leaf_pred.data(), leaf_pred.size(), __FILE__, __LINE__);
num_data_ = static_cast<data_size_t>(leaf_pred.size());
} else {
CopyFromHostToCUDADevice<int>(cuda_data_index_to_leaf_index_, leaf_pred.data(), leaf_pred.size(), __FILE__, __LINE__);
}
num_leaves_ = num_leaves;
cur_num_leaves_ = num_leaves;
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include "cuda_data_partition.hpp"
#include <LightGBM/cuda/cuda_algorithms.hpp>
#include <LightGBM/tree.h>
#include <algorithm>
#include <vector>
namespace LightGBM {
__global__ void FillDataIndicesBeforeTrainKernel(const data_size_t num_data,
data_size_t* data_indices, int* cuda_data_index_to_leaf_index) {
const unsigned int data_index = threadIdx.x + blockIdx.x * blockDim.x;
if (data_index < num_data) {
data_indices[data_index] = data_index;
cuda_data_index_to_leaf_index[data_index] = 0;
}
}
__global__ void FillDataIndexToLeafIndexKernel(
const data_size_t num_data,
const data_size_t* data_indices,
int* data_index_to_leaf_index) {
const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
if (data_index < num_data) {
data_index_to_leaf_index[data_indices[data_index]] = 0;
}
}
void CUDADataPartition::LaunchFillDataIndicesBeforeTrain() {
const data_size_t num_data_in_root = root_num_data();
const int num_blocks = (num_data_in_root + FILL_INDICES_BLOCK_SIZE_DATA_PARTITION - 1) / FILL_INDICES_BLOCK_SIZE_DATA_PARTITION;
FillDataIndicesBeforeTrainKernel<<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(num_data_in_root, cuda_data_indices_, cuda_data_index_to_leaf_index_);
}
void CUDADataPartition::LaunchFillDataIndexToLeafIndex() {
const data_size_t num_data_in_root = root_num_data();
const int num_blocks = (num_data_in_root + FILL_INDICES_BLOCK_SIZE_DATA_PARTITION - 1) / FILL_INDICES_BLOCK_SIZE_DATA_PARTITION;
FillDataIndexToLeafIndexKernel<<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(num_data_in_root, cuda_data_indices_, cuda_data_index_to_leaf_index_);
}
__device__ __forceinline__ void PrepareOffset(const data_size_t num_data_in_leaf, uint16_t* block_to_left_offset,
data_size_t* block_to_left_offset_buffer, data_size_t* block_to_right_offset_buffer,
const uint16_t thread_to_left_offset_cnt, uint16_t* shared_mem_buffer) {
const unsigned int threadIdx_x = threadIdx.x;
const unsigned int blockDim_x = blockDim.x;
const uint16_t thread_to_left_offset = ShufflePrefixSum<uint16_t>(thread_to_left_offset_cnt, shared_mem_buffer);
const data_size_t num_data_in_block = (blockIdx.x + 1) * blockDim_x <= num_data_in_leaf ? static_cast<data_size_t>(blockDim_x) :
num_data_in_leaf - static_cast<data_size_t>(blockIdx.x * blockDim_x);
if (static_cast<data_size_t>(threadIdx_x) < num_data_in_block) {
block_to_left_offset[threadIdx_x] = thread_to_left_offset;
}
if (threadIdx_x == blockDim_x - 1) {
if (num_data_in_block > 0) {
const data_size_t data_to_left = static_cast<data_size_t>(thread_to_left_offset);
block_to_left_offset_buffer[blockIdx.x + 1] = data_to_left;
block_to_right_offset_buffer[blockIdx.x + 1] = num_data_in_block - data_to_left;
} else {
block_to_left_offset_buffer[blockIdx.x + 1] = 0;
block_to_right_offset_buffer[blockIdx.x + 1] = 0;
}
}
}
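// device-side counterpart of Common::FindInBitset: tests whether bit `pos` is set
// in a bitset stored as n 32-bit words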
template <typename T>
__device__ bool CUDAFindInBitset(const uint32_t* bits, int n, T pos) {
int i1 = pos / 32;
if (i1 >= n) {
return false;
}
int i2 = pos % 32;
return (bits[i1] >> i2) & 1;
}
#define UpdateDataIndexToLeafIndexKernel_PARAMS \
const BIN_TYPE* column_data, \
const data_size_t num_data_in_leaf, \
const data_size_t* data_indices_in_leaf, \
const uint32_t th, \
const uint32_t t_zero_bin, \
const uint32_t max_bin, \
const uint32_t min_bin, \
const int left_leaf_index, \
const int right_leaf_index, \
const int default_leaf_index, \
const int missing_default_leaf_index
#define UpdateDataIndexToLeafIndex_ARGS \
column_data, \
num_data_in_leaf, \
data_indices_in_leaf, th, \
t_zero_bin, \
max_bin, \
min_bin, \
left_leaf_index, \
right_leaf_index, \
default_leaf_index, \
missing_default_leaf_index
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, bool MAX_TO_LEFT, bool USE_MIN_BIN, typename BIN_TYPE>
__global__ void UpdateDataIndexToLeafIndexKernel(
UpdateDataIndexToLeafIndexKernel_PARAMS,
int* cuda_data_index_to_leaf_index) {
const unsigned int local_data_index = blockIdx.x * blockDim.x + threadIdx.x;
if (local_data_index < num_data_in_leaf) {
const unsigned int global_data_index = data_indices_in_leaf[local_data_index];
const uint32_t bin = static_cast<uint32_t>(column_data[global_data_index]);
if (!MIN_IS_MAX) {
if ((MISSING_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
(MISSING_IS_NA && !MFB_IS_NA && bin == max_bin)) {
cuda_data_index_to_leaf_index[global_data_index] = missing_default_leaf_index;
} else if ((USE_MIN_BIN && (bin < min_bin || bin > max_bin)) ||
(!USE_MIN_BIN && bin == 0)) {
if ((MISSING_IS_NA && MFB_IS_NA) || (MISSING_IS_ZERO && MFB_IS_ZERO)) {
cuda_data_index_to_leaf_index[global_data_index] = missing_default_leaf_index;
} else {
cuda_data_index_to_leaf_index[global_data_index] = default_leaf_index;
}
} else if (bin > th) {
cuda_data_index_to_leaf_index[global_data_index] = right_leaf_index;
} else {
cuda_data_index_to_leaf_index[global_data_index] = left_leaf_index;
}
} else {
if (MISSING_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
cuda_data_index_to_leaf_index[global_data_index] = missing_default_leaf_index;
} else if (bin != max_bin) {
if ((MISSING_IS_NA && MFB_IS_NA) || (MISSING_IS_ZERO && MFB_IS_ZERO)) {
cuda_data_index_to_leaf_index[global_data_index] = missing_default_leaf_index;
} else {
cuda_data_index_to_leaf_index[global_data_index] = default_leaf_index;
}
} else {
if (MISSING_IS_NA && !MFB_IS_NA) {
cuda_data_index_to_leaf_index[global_data_index] = missing_default_leaf_index;
} else {
if (!MAX_TO_LEFT) {
cuda_data_index_to_leaf_index[global_data_index] = right_leaf_index;
} else {
cuda_data_index_to_leaf_index[global_data_index] = left_leaf_index;
}
}
}
}
}
}
template <typename BIN_TYPE>
void CUDADataPartition::LaunchUpdateDataIndexToLeafIndexKernel(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool missing_is_zero,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column) {
if (min_bin < max_bin) {
if (!missing_is_zero) {
LaunchUpdateDataIndexToLeafIndexKernel_Inner0<false, false, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_to_left, is_single_feature_in_column);
} else {
LaunchUpdateDataIndexToLeafIndexKernel_Inner0<false, true, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_to_left, is_single_feature_in_column);
}
} else {
if (!missing_is_zero) {
LaunchUpdateDataIndexToLeafIndexKernel_Inner0<true, false, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_to_left, is_single_feature_in_column);
} else {
LaunchUpdateDataIndexToLeafIndexKernel_Inner0<true, true, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_to_left, is_single_feature_in_column);
}
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, typename BIN_TYPE>
void CUDADataPartition::LaunchUpdateDataIndexToLeafIndexKernel_Inner0(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column) {
if (!missing_is_na) {
LaunchUpdateDataIndexToLeafIndexKernel_Inner1<MIN_IS_MAX, MISSING_IS_ZERO, false, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, mfb_is_zero, mfb_is_na, max_to_left, is_single_feature_in_column);
} else {
LaunchUpdateDataIndexToLeafIndexKernel_Inner1<MIN_IS_MAX, MISSING_IS_ZERO, true, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, mfb_is_zero, mfb_is_na, max_to_left, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, typename BIN_TYPE>
void CUDADataPartition::LaunchUpdateDataIndexToLeafIndexKernel_Inner1(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column) {
if (!mfb_is_zero) {
LaunchUpdateDataIndexToLeafIndexKernel_Inner2<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, false, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, mfb_is_na, max_to_left, is_single_feature_in_column);
} else {
LaunchUpdateDataIndexToLeafIndexKernel_Inner2<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, true, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, mfb_is_na, max_to_left, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, typename BIN_TYPE>
void CUDADataPartition::LaunchUpdateDataIndexToLeafIndexKernel_Inner2(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column) {
if (!mfb_is_na) {
LaunchUpdateDataIndexToLeafIndexKernel_Inner3<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, false, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, max_to_left, is_single_feature_in_column);
} else {
LaunchUpdateDataIndexToLeafIndexKernel_Inner3<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, true, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, max_to_left, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, typename BIN_TYPE>
void CUDADataPartition::LaunchUpdateDataIndexToLeafIndexKernel_Inner3(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool max_to_left,
const bool is_single_feature_in_column) {
if (!max_to_left) {
LaunchUpdateDataIndexToLeafIndexKernel_Inner4<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, false, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, is_single_feature_in_column);
} else {
LaunchUpdateDataIndexToLeafIndexKernel_Inner4<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, true, BIN_TYPE>
(UpdateDataIndexToLeafIndex_ARGS, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, bool MAX_TO_LEFT, typename BIN_TYPE>
void CUDADataPartition::LaunchUpdateDataIndexToLeafIndexKernel_Inner4(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool is_single_feature_in_column) {
if (!is_single_feature_in_column) {
UpdateDataIndexToLeafIndexKernel<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, MAX_TO_LEFT, true, BIN_TYPE>
<<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(
UpdateDataIndexToLeafIndex_ARGS,
cuda_data_index_to_leaf_index_);
} else {
UpdateDataIndexToLeafIndexKernel<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, MAX_TO_LEFT, false, BIN_TYPE>
<<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(
UpdateDataIndexToLeafIndex_ARGS,
cuda_data_index_to_leaf_index_);
}
}
#define GenDataToLeftBitVectorKernel_PARMS \
const BIN_TYPE* column_data, \
const data_size_t num_data_in_leaf, \
const data_size_t* data_indices_in_leaf, \
const uint32_t th, \
const uint32_t t_zero_bin, \
const uint32_t max_bin, \
const uint32_t min_bin, \
const uint8_t split_default_to_left, \
const uint8_t split_missing_default_to_left
#define GenBitVector_ARGS \
column_data, \
num_data_in_leaf, \
data_indices_in_leaf, \
th, \
t_zero_bin, \
max_bin, \
min_bin, \
split_default_to_left, \
split_missing_default_to_left
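// template flags are resolved on the host by the LaunchGenDataToLeftBitVectorKernelInner* chain below:
//   MIN_IS_MAX                      - the feature has a single bin value in this column (min_bin == max_bin)
//   MISSING_IS_ZERO / MISSING_IS_NA - how missing values of the feature are encoded
//   MFB_IS_ZERO / MFB_IS_NA         - whether the most frequent (default) bin encodes zero / NaN
//   MAX_TO_LEFT                     - data at the maximum bin satisfy the threshold and go to the left child
//   USE_MIN_BIN                     - the column is shared by several features, so bins are checked against [min_bin, max_bin]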
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, bool MAX_TO_LEFT, bool USE_MIN_BIN, typename BIN_TYPE>
__global__ void GenDataToLeftBitVectorKernel(
GenDataToLeftBitVectorKernel_PARMS,
uint16_t* block_to_left_offset,
data_size_t* block_to_left_offset_buffer,
data_size_t* block_to_right_offset_buffer) {
__shared__ uint16_t shared_mem_buffer[32];
uint16_t thread_to_left_offset_cnt = 0;
const unsigned int local_data_index = blockIdx.x * blockDim.x + threadIdx.x;
if (local_data_index < num_data_in_leaf) {
const unsigned int global_data_index = data_indices_in_leaf[local_data_index];
const uint32_t bin = static_cast<uint32_t>(column_data[global_data_index]);
if (!MIN_IS_MAX) {
if ((MISSING_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) ||
(MISSING_IS_NA && !MFB_IS_NA && bin == max_bin)) {
thread_to_left_offset_cnt = split_missing_default_to_left;
} else if ((USE_MIN_BIN && (bin < min_bin || bin > max_bin)) ||
(!USE_MIN_BIN && bin == 0)) {
if ((MISSING_IS_NA && MFB_IS_NA) || (MISSING_IS_ZERO && MFB_IS_ZERO)) {
thread_to_left_offset_cnt = split_missing_default_to_left;
} else {
thread_to_left_offset_cnt = split_default_to_left;
}
} else if (bin <= th) {
thread_to_left_offset_cnt = 1;
}
} else {
if (MISSING_IS_ZERO && !MFB_IS_ZERO && bin == t_zero_bin) {
thread_to_left_offset_cnt = split_missing_default_to_left;
} else if (bin != max_bin) {
if ((MISSING_IS_NA && MFB_IS_NA) || (MISSING_IS_ZERO && MFB_IS_ZERO)) {
thread_to_left_offset_cnt = split_missing_default_to_left;
} else {
thread_to_left_offset_cnt = split_default_to_left;
}
} else {
if (MISSING_IS_NA && !MFB_IS_NA) {
thread_to_left_offset_cnt = split_missing_default_to_left;
} else if (MAX_TO_LEFT) {
thread_to_left_offset_cnt = 1;
}
}
}
}
__syncthreads();
PrepareOffset(num_data_in_leaf, block_to_left_offset + blockIdx.x * blockDim.x, block_to_left_offset_buffer, block_to_right_offset_buffer,
thread_to_left_offset_cnt, shared_mem_buffer);
}
template <typename BIN_TYPE>
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernelInner(
GenDataToLeftBitVectorKernel_PARMS,
const bool missing_is_zero,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column) {
if (min_bin < max_bin) {
if (!missing_is_zero) {
LaunchGenDataToLeftBitVectorKernelInner0<false, false, BIN_TYPE>
(GenBitVector_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
} else {
LaunchGenDataToLeftBitVectorKernelInner0<false, true, BIN_TYPE>
(GenBitVector_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
}
} else {
if (!missing_is_zero) {
LaunchGenDataToLeftBitVectorKernelInner0<true, false, BIN_TYPE>
(GenBitVector_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
} else {
LaunchGenDataToLeftBitVectorKernelInner0<true, true, BIN_TYPE>
(GenBitVector_ARGS, missing_is_na, mfb_is_zero, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
}
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, typename BIN_TYPE>
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernelInner0(
GenDataToLeftBitVectorKernel_PARMS,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column) {
if (!missing_is_na) {
LaunchGenDataToLeftBitVectorKernelInner1<MIN_IS_MAX, MISSING_IS_ZERO, false, BIN_TYPE>
(GenBitVector_ARGS, mfb_is_zero, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
} else {
LaunchGenDataToLeftBitVectorKernelInner1<MIN_IS_MAX, MISSING_IS_ZERO, true, BIN_TYPE>
(GenBitVector_ARGS, mfb_is_zero, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, typename BIN_TYPE>
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernelInner1(
GenDataToLeftBitVectorKernel_PARMS,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column) {
if (!mfb_is_zero) {
LaunchGenDataToLeftBitVectorKernelInner2<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, false, BIN_TYPE>
(GenBitVector_ARGS, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
} else {
LaunchGenDataToLeftBitVectorKernelInner2<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, true, BIN_TYPE>
(GenBitVector_ARGS, mfb_is_na, max_bin_to_left, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, typename BIN_TYPE>
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernelInner2(
GenDataToLeftBitVectorKernel_PARMS,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column) {
if (!mfb_is_na) {
LaunchGenDataToLeftBitVectorKernelInner3
<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, false, BIN_TYPE>
(GenBitVector_ARGS, max_bin_to_left, is_single_feature_in_column);
} else {
LaunchGenDataToLeftBitVectorKernelInner3
<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, true, BIN_TYPE>
(GenBitVector_ARGS, max_bin_to_left, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, typename BIN_TYPE>
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernelInner3(
GenDataToLeftBitVectorKernel_PARMS,
const bool max_bin_to_left,
const bool is_single_feature_in_column) {
if (!max_bin_to_left) {
LaunchGenDataToLeftBitVectorKernelInner4
<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, false, BIN_TYPE>
(GenBitVector_ARGS, is_single_feature_in_column);
} else {
LaunchGenDataToLeftBitVectorKernelInner4
<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, true, BIN_TYPE>
(GenBitVector_ARGS, is_single_feature_in_column);
}
}
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, bool MAX_TO_LEFT, typename BIN_TYPE>
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernelInner4(
GenDataToLeftBitVectorKernel_PARMS,
const bool is_single_feature_in_column) {
if (!is_single_feature_in_column) {
GenDataToLeftBitVectorKernel
<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, MAX_TO_LEFT, true, BIN_TYPE>
<<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_ARGS,
cuda_block_to_left_offset_, cuda_block_data_to_left_offset_, cuda_block_data_to_right_offset_);
} else {
GenDataToLeftBitVectorKernel
<MIN_IS_MAX, MISSING_IS_ZERO, MISSING_IS_NA, MFB_IS_ZERO, MFB_IS_NA, MAX_TO_LEFT, false, BIN_TYPE>
<<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_ARGS,
cuda_block_to_left_offset_, cuda_block_data_to_left_offset_, cuda_block_data_to_right_offset_);
}
}
void CUDADataPartition::LaunchGenDataToLeftBitVectorKernel(
const data_size_t num_data_in_leaf,
const int split_feature_index,
const uint32_t split_threshold,
const uint8_t split_default_left,
const data_size_t leaf_data_start,
const int left_leaf_index,
const int right_leaf_index) {
const bool missing_is_zero = static_cast<bool>(cuda_column_data_->feature_missing_is_zero(split_feature_index));
const bool missing_is_na = static_cast<bool>(cuda_column_data_->feature_missing_is_na(split_feature_index));
const bool mfb_is_zero = static_cast<bool>(cuda_column_data_->feature_mfb_is_zero(split_feature_index));
const bool mfb_is_na = static_cast<bool>(cuda_column_data_->feature_mfb_is_na(split_feature_index));
const bool is_single_feature_in_column = is_single_feature_in_column_[split_feature_index];
const uint32_t default_bin = cuda_column_data_->feature_default_bin(split_feature_index);
const uint32_t most_freq_bin = cuda_column_data_->feature_most_freq_bin(split_feature_index);
const uint32_t min_bin = is_single_feature_in_column ? 1 : cuda_column_data_->feature_min_bin(split_feature_index);
const uint32_t max_bin = cuda_column_data_->feature_max_bin(split_feature_index);
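// convert the feature-local split threshold and zero bin into column-wide bin values;
// both are decremented by one when the most frequent bin is 0 (the mfb offset)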
uint32_t th = split_threshold + min_bin;
uint32_t t_zero_bin = min_bin + default_bin;
if (most_freq_bin == 0) {
--th;
--t_zero_bin;
}
uint8_t split_default_to_left = 0;
uint8_t split_missing_default_to_left = 0;
int default_leaf_index = right_leaf_index;
int missing_default_leaf_index = right_leaf_index;
if (most_freq_bin <= split_threshold) {
split_default_to_left = 1;
default_leaf_index = left_leaf_index;
}
if (missing_is_zero || missing_is_na) {
if (split_default_left) {
split_missing_default_to_left = 1;
missing_default_leaf_index = left_leaf_index;
}
}
const int column_index = cuda_column_data_->feature_to_column(split_feature_index);
const uint8_t bit_type = cuda_column_data_->column_bit_type(column_index);
const bool max_bin_to_left = (max_bin <= th);
const data_size_t* data_indices_in_leaf = cuda_data_indices_ + leaf_data_start;
const void* column_data_pointer = cuda_column_data_->GetColumnData(column_index);
if (bit_type == 8) {
const uint8_t* column_data = reinterpret_cast<const uint8_t*>(column_data_pointer);
LaunchGenDataToLeftBitVectorKernelInner<uint8_t>(
GenBitVector_ARGS,
missing_is_zero,
missing_is_na,
mfb_is_zero,
mfb_is_na,
max_bin_to_left,
is_single_feature_in_column);
LaunchUpdateDataIndexToLeafIndexKernel<uint8_t>(
UpdateDataIndexToLeafIndex_ARGS,
missing_is_zero,
missing_is_na,
mfb_is_zero,
mfb_is_na,
max_bin_to_left,
is_single_feature_in_column);
} else if (bit_type == 16) {
const uint16_t* column_data = reinterpret_cast<const uint16_t*>(column_data_pointer);
LaunchGenDataToLeftBitVectorKernelInner<uint16_t>(
GenBitVector_ARGS,
missing_is_zero,
missing_is_na,
mfb_is_zero,
mfb_is_na,
max_bin_to_left,
is_single_feature_in_column);
LaunchUpdateDataIndexToLeafIndexKernel<uint16_t>(
UpdateDataIndexToLeafIndex_ARGS,
missing_is_zero,
missing_is_na,
mfb_is_zero,
mfb_is_na,
max_bin_to_left,
is_single_feature_in_column);
} else if (bit_type == 32) {
const uint32_t* column_data = reinterpret_cast<const uint32_t*>(column_data_pointer);
LaunchGenDataToLeftBitVectorKernelInner<uint32_t>(
GenBitVector_ARGS,
missing_is_zero,
missing_is_na,
mfb_is_zero,
mfb_is_na,
max_bin_to_left,
is_single_feature_in_column);
LaunchUpdateDataIndexToLeafIndexKernel<uint32_t>(
UpdateDataIndexToLeafIndex_ARGS,
missing_is_zero,
missing_is_na,
mfb_is_zero,
mfb_is_na,
max_bin_to_left,
is_single_feature_in_column);
}
}
#undef UpdateDataIndexToLeafIndexKernel_PARAMS
#undef UpdateDataIndexToLeafIndex_ARGS
#undef GenDataToLeftBitVectorKernel_PARMS
#undef GenBitVector_ARGS
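// categorical split kernels: a data point goes to the left child when its bin is
// contained in the split bitset; bins outside [min_bin, max_bin] (or bin 0 when the
// feature owns the whole column) fall back to the default direction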
template <typename BIN_TYPE, bool USE_MIN_BIN>
__global__ void UpdateDataIndexToLeafIndexKernel_Categorical(
const data_size_t num_data_in_leaf, const data_size_t* data_indices_in_leaf,
const uint32_t* bitset, const int bitset_len, const BIN_TYPE* column_data,
// values from feature
const uint32_t max_bin, const uint32_t min_bin, const int8_t mfb_offset,
int* cuda_data_index_to_leaf_index, const int left_leaf_index, const int right_leaf_index,
const int default_leaf_index) {
const unsigned int local_data_index = blockIdx.x * blockDim.x + threadIdx.x;
if (local_data_index < num_data_in_leaf) {
const unsigned int global_data_index = data_indices_in_leaf[local_data_index];
const uint32_t bin = static_cast<uint32_t>(column_data[global_data_index]);
if (USE_MIN_BIN && (bin < min_bin || bin > max_bin)) {
cuda_data_index_to_leaf_index[global_data_index] = default_leaf_index;
} else if (!USE_MIN_BIN && bin == 0) {
cuda_data_index_to_leaf_index[global_data_index] = default_leaf_index;
} else if (CUDAFindInBitset(bitset, bitset_len, bin - min_bin + mfb_offset)) {
cuda_data_index_to_leaf_index[global_data_index] = left_leaf_index;
} else {
cuda_data_index_to_leaf_index[global_data_index] = right_leaf_index;
}
}
}
// for categorical features
template <typename BIN_TYPE, bool USE_MIN_BIN>
__global__ void GenDataToLeftBitVectorKernel_Categorical(
const data_size_t num_data_in_leaf, const data_size_t* data_indices_in_leaf,
const uint32_t* bitset, int bitset_len, const BIN_TYPE* column_data,
// values from feature
const uint32_t max_bin, const uint32_t min_bin, const int8_t mfb_offset,
const uint8_t split_default_to_left,
uint16_t* block_to_left_offset,
data_size_t* block_to_left_offset_buffer, data_size_t* block_to_right_offset_buffer) {
__shared__ uint16_t shared_mem_buffer[32];
uint16_t thread_to_left_offset_cnt = 0;
const unsigned int local_data_index = blockIdx.x * blockDim.x + threadIdx.x;
if (local_data_index < num_data_in_leaf) {
const unsigned int global_data_index = data_indices_in_leaf[local_data_index];
const uint32_t bin = static_cast<uint32_t>(column_data[global_data_index]);
if (USE_MIN_BIN && (bin < min_bin || bin > max_bin)) {
thread_to_left_offset_cnt = split_default_to_left;
} else if (!USE_MIN_BIN && bin == 0) {
thread_to_left_offset_cnt = split_default_to_left;
} else if (CUDAFindInBitset(bitset, bitset_len, bin - min_bin + mfb_offset)) {
thread_to_left_offset_cnt = 1;
}
}
__syncthreads();
PrepareOffset(num_data_in_leaf, block_to_left_offset + blockIdx.x * blockDim.x, block_to_left_offset_buffer, block_to_right_offset_buffer,
thread_to_left_offset_cnt, shared_mem_buffer);
}
#define GenBitVector_Categorical_ARGS \
num_data_in_leaf, data_indices_in_leaf, \
bitset, bitset_len, \
column_data, max_bin, min_bin, mfb_offset, split_default_to_left, \
cuda_block_to_left_offset_, cuda_block_data_to_left_offset_, cuda_block_data_to_right_offset_
#define UpdateDataIndexToLeafIndex_Categorical_ARGS \
num_data_in_leaf, data_indices_in_leaf, \
bitset, bitset_len, \
column_data, max_bin, min_bin, mfb_offset, \
cuda_data_index_to_leaf_index_, left_leaf_index, right_leaf_index, default_leaf_index
void CUDADataPartition::LaunchGenDataToLeftBitVectorCategoricalKernel(
const data_size_t num_data_in_leaf,
const int split_feature_index,
const uint32_t* bitset,
const int bitset_len,
const uint8_t split_default_left,
const data_size_t leaf_data_start,
const int left_leaf_index,
const int right_leaf_index) {
const data_size_t* data_indices_in_leaf = cuda_data_indices_ + leaf_data_start;
const int column_index = cuda_column_data_->feature_to_column(split_feature_index);
const uint8_t bit_type = cuda_column_data_->column_bit_type(column_index);
const bool is_single_feature_in_column = is_single_feature_in_column_[split_feature_index];
const uint32_t min_bin = is_single_feature_in_column ? 1 : cuda_column_data_->feature_min_bin(split_feature_index);
const uint32_t max_bin = cuda_column_data_->feature_max_bin(split_feature_index);
const uint32_t most_freq_bin = cuda_column_data_->feature_most_freq_bin(split_feature_index);
const uint32_t default_bin = cuda_column_data_->feature_default_bin(split_feature_index);
const void* column_data_pointer = cuda_column_data_->GetColumnData(column_index);
const int8_t mfb_offset = static_cast<int8_t>(most_freq_bin == 0);
std::vector<uint32_t> host_bitset(bitset_len, 0);
CopyFromCUDADeviceToHost<uint32_t>(host_bitset.data(), bitset, bitset_len, __FILE__, __LINE__);
uint8_t split_default_to_left = 0;
int default_leaf_index = right_leaf_index;
if (most_freq_bin > 0 && Common::FindInBitset(host_bitset.data(), bitset_len, most_freq_bin)) {
split_default_to_left = 1;
default_leaf_index = left_leaf_index;
}
if (bit_type == 8) {
const uint8_t* column_data = reinterpret_cast<const uint8_t*>(column_data_pointer);
if (is_single_feature_in_column) {
GenDataToLeftBitVectorKernel_Categorical<uint8_t, false><<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_Categorical_ARGS);
UpdateDataIndexToLeafIndexKernel_Categorical<uint8_t, false><<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(UpdateDataIndexToLeafIndex_Categorical_ARGS);
} else {
GenDataToLeftBitVectorKernel_Categorical<uint8_t, true><<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_Categorical_ARGS);
UpdateDataIndexToLeafIndexKernel_Categorical<uint8_t, true><<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(UpdateDataIndexToLeafIndex_Categorical_ARGS);
}
} else if (bit_type == 16) {
const uint16_t* column_data = reinterpret_cast<const uint16_t*>(column_data_pointer);
if (is_single_feature_in_column) {
GenDataToLeftBitVectorKernel_Categorical<uint16_t, false><<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_Categorical_ARGS);
UpdateDataIndexToLeafIndexKernel_Categorical<uint16_t, false><<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(UpdateDataIndexToLeafIndex_Categorical_ARGS);
} else {
GenDataToLeftBitVectorKernel_Categorical<uint16_t, true><<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_Categorical_ARGS);
UpdateDataIndexToLeafIndexKernel_Categorical<uint16_t, true><<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(UpdateDataIndexToLeafIndex_Categorical_ARGS);
}
} else if (bit_type == 32) {
const uint32_t* column_data = reinterpret_cast<const uint32_t*>(column_data_pointer);
if (is_single_feature_in_column) {
GenDataToLeftBitVectorKernel_Categorical<uint32_t, false><<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_Categorical_ARGS);
UpdateDataIndexToLeafIndexKernel_Categorical<uint32_t, false><<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(UpdateDataIndexToLeafIndex_Categorical_ARGS);
} else {
GenDataToLeftBitVectorKernel_Categorical<uint32_t, true><<<grid_dim_, block_dim_, 0, cuda_streams_[0]>>>(GenBitVector_Categorical_ARGS);
UpdateDataIndexToLeafIndexKernel_Categorical<uint32_t, true><<<grid_dim_, block_dim_, 0, cuda_streams_[3]>>>(UpdateDataIndexToLeafIndex_Categorical_ARGS);
}
}
}
#undef GenBitVector_Categorical_ARGS
#undef UpdateDataIndexToLeafIndex_Categorical_ARGS
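// turns the per-block left/right counts produced by PrepareOffset into global write
// offsets (a prefix sum over blocks) and updates the data start/end/count bookkeeping
// of the two child leaves; kernel 0 handles the general case where each thread aggregates
// several blocks, kernel 1 the case where all block counts fit into a single thread block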
__global__ void AggregateBlockOffsetKernel0(
const int left_leaf_index,
const int right_leaf_index,
data_size_t* block_to_left_offset_buffer,
data_size_t* block_to_right_offset_buffer, data_size_t* cuda_leaf_data_start,
data_size_t* cuda_leaf_data_end, data_size_t* cuda_leaf_num_data, const data_size_t* cuda_data_indices,
const data_size_t num_blocks) {
__shared__ uint32_t shared_mem_buffer[32];
__shared__ uint32_t to_left_total_count;
const data_size_t num_data_in_leaf = cuda_leaf_num_data[left_leaf_index];
const unsigned int blockDim_x = blockDim.x;
const unsigned int threadIdx_x = threadIdx.x;
const data_size_t num_blocks_plus_1 = num_blocks + 1;
const uint32_t num_blocks_per_thread = (num_blocks_plus_1 + blockDim_x - 1) / blockDim_x;
const uint32_t remain = num_blocks_plus_1 - ((num_blocks_per_thread - 1) * blockDim_x);
const uint32_t remain_offset = remain * num_blocks_per_thread;
uint32_t thread_start_block_index = 0;
uint32_t thread_end_block_index = 0;
if (threadIdx_x < remain) {
thread_start_block_index = threadIdx_x * num_blocks_per_thread;
thread_end_block_index = min(thread_start_block_index + num_blocks_per_thread, num_blocks_plus_1);
} else {
thread_start_block_index = remain_offset + (num_blocks_per_thread - 1) * (threadIdx_x - remain);
thread_end_block_index = min(thread_start_block_index + num_blocks_per_thread - 1, num_blocks_plus_1);
}
if (threadIdx.x == 0) {
block_to_right_offset_buffer[0] = 0;
}
__syncthreads();
for (uint32_t block_index = thread_start_block_index + 1; block_index < thread_end_block_index; ++block_index) {
block_to_left_offset_buffer[block_index] += block_to_left_offset_buffer[block_index - 1];
block_to_right_offset_buffer[block_index] += block_to_right_offset_buffer[block_index - 1];
}
__syncthreads();
uint32_t block_to_left_offset = 0;
uint32_t block_to_right_offset = 0;
if (thread_start_block_index < thread_end_block_index && thread_start_block_index > 1) {
block_to_left_offset = block_to_left_offset_buffer[thread_start_block_index - 1];
block_to_right_offset = block_to_right_offset_buffer[thread_start_block_index - 1];
}
block_to_left_offset = ShufflePrefixSum<uint32_t>(block_to_left_offset, shared_mem_buffer);
__syncthreads();
block_to_right_offset = ShufflePrefixSum<uint32_t>(block_to_right_offset, shared_mem_buffer);
if (threadIdx_x == blockDim_x - 1) {
to_left_total_count = block_to_left_offset + block_to_left_offset_buffer[num_blocks];
}
__syncthreads();
const uint32_t to_left_thread_block_offset = block_to_left_offset;
const uint32_t to_right_thread_block_offset = block_to_right_offset + to_left_total_count;
for (uint32_t block_index = thread_start_block_index; block_index < thread_end_block_index; ++block_index) {
block_to_left_offset_buffer[block_index] += to_left_thread_block_offset;
block_to_right_offset_buffer[block_index] += to_right_thread_block_offset;
}
__syncthreads();
if (blockIdx.x == 0 && threadIdx.x == 0) {
const data_size_t old_leaf_data_end = cuda_leaf_data_end[left_leaf_index];
cuda_leaf_data_end[left_leaf_index] = cuda_leaf_data_start[left_leaf_index] + static_cast<data_size_t>(to_left_total_count);
cuda_leaf_num_data[left_leaf_index] = static_cast<data_size_t>(to_left_total_count);
cuda_leaf_data_start[right_leaf_index] = cuda_leaf_data_end[left_leaf_index];
cuda_leaf_data_end[right_leaf_index] = old_leaf_data_end;
cuda_leaf_num_data[right_leaf_index] = num_data_in_leaf - static_cast<data_size_t>(to_left_total_count);
}
}
__global__ void AggregateBlockOffsetKernel1(
const int left_leaf_index,
const int right_leaf_index,
data_size_t* block_to_left_offset_buffer,
data_size_t* block_to_right_offset_buffer, data_size_t* cuda_leaf_data_start,
data_size_t* cuda_leaf_data_end, data_size_t* cuda_leaf_num_data, const data_size_t* cuda_data_indices,
const data_size_t num_blocks) {
__shared__ uint32_t shared_mem_buffer[32];
__shared__ uint32_t to_left_total_count;
const data_size_t num_data_in_leaf = cuda_leaf_num_data[left_leaf_index];
const unsigned int threadIdx_x = threadIdx.x;
uint32_t block_to_left_offset = 0;
uint32_t block_to_right_offset = 0;
if (threadIdx_x < static_cast<unsigned int>(num_blocks)) {
block_to_left_offset = block_to_left_offset_buffer[threadIdx_x + 1];
block_to_right_offset = block_to_right_offset_buffer[threadIdx_x + 1];
}
block_to_left_offset = ShufflePrefixSum<uint32_t>(block_to_left_offset, shared_mem_buffer);
__syncthreads();
block_to_right_offset = ShufflePrefixSum<uint32_t>(block_to_right_offset, shared_mem_buffer);
if (threadIdx.x == blockDim.x - 1) {
to_left_total_count = block_to_left_offset;
}
__syncthreads();
if (threadIdx_x < static_cast<unsigned int>(num_blocks)) {
block_to_left_offset_buffer[threadIdx_x + 1] = block_to_left_offset;
block_to_right_offset_buffer[threadIdx_x + 1] = block_to_right_offset + to_left_total_count;
}
if (threadIdx_x == 0) {
block_to_right_offset_buffer[0] = to_left_total_count;
}
__syncthreads();
if (blockIdx.x == 0 && threadIdx.x == 0) {
const data_size_t old_leaf_data_end = cuda_leaf_data_end[left_leaf_index];
cuda_leaf_data_end[left_leaf_index] = cuda_leaf_data_start[left_leaf_index] + static_cast<data_size_t>(to_left_total_count);
cuda_leaf_num_data[left_leaf_index] = static_cast<data_size_t>(to_left_total_count);
cuda_leaf_data_start[right_leaf_index] = cuda_leaf_data_end[left_leaf_index];
cuda_leaf_data_end[right_leaf_index] = old_leaf_data_end;
cuda_leaf_num_data[right_leaf_index] = num_data_in_leaf - static_cast<data_size_t>(to_left_total_count);
}
}
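// Updates the tree structure metadata after a split. Launched with a small fixed
// number of threads; each global thread index writes one field of the outputs: the
// leaf output values, the split information buffer that is copied back to the host,
// and the smaller/larger leaf splits structures. The child with fewer data points
// becomes the "smaller" leaf and receives a fresh histogram slot, while the other
// child keeps the parent's histogram so that its own histogram can later be obtained
// by subtraction.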
__global__ void SplitTreeStructureKernel(const int left_leaf_index,
const int right_leaf_index,
data_size_t* block_to_left_offset_buffer,
data_size_t* block_to_right_offset_buffer, data_size_t* cuda_leaf_data_start,
data_size_t* cuda_leaf_data_end, data_size_t* cuda_leaf_num_data, const data_size_t* cuda_data_indices,
const CUDASplitInfo* best_split_info,
// for leaf splits information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
const int num_total_bin,
hist_t* cuda_hist, hist_t** cuda_hist_pool,
double* cuda_leaf_output,
int* cuda_split_info_buffer) {
const unsigned int to_left_total_cnt = cuda_leaf_num_data[left_leaf_index];
double* cuda_split_info_buffer_for_hessians = reinterpret_cast<double*>(cuda_split_info_buffer + 8);
const unsigned int global_thread_index = blockIdx.x * blockDim.x + threadIdx.x;
if (global_thread_index == 0) {
cuda_leaf_output[left_leaf_index] = best_split_info->left_value;
} else if (global_thread_index == 1) {
cuda_leaf_output[right_leaf_index] = best_split_info->right_value;
} else if (global_thread_index == 2) {
cuda_split_info_buffer[0] = left_leaf_index;
} else if (global_thread_index == 3) {
cuda_split_info_buffer[1] = cuda_leaf_num_data[left_leaf_index];
} else if (global_thread_index == 4) {
cuda_split_info_buffer[2] = cuda_leaf_data_start[left_leaf_index];
} else if (global_thread_index == 5) {
cuda_split_info_buffer[3] = right_leaf_index;
} else if (global_thread_index == 6) {
cuda_split_info_buffer[4] = cuda_leaf_num_data[right_leaf_index];
} else if (global_thread_index == 7) {
cuda_split_info_buffer[5] = cuda_leaf_data_start[right_leaf_index];
} else if (global_thread_index == 8) {
cuda_split_info_buffer_for_hessians[0] = best_split_info->left_sum_hessians;
cuda_split_info_buffer_for_hessians[2] = best_split_info->left_sum_gradients;
} else if (global_thread_index == 9) {
cuda_split_info_buffer_for_hessians[1] = best_split_info->right_sum_hessians;
cuda_split_info_buffer_for_hessians[3] = best_split_info->right_sum_gradients;
}
if (cuda_leaf_num_data[left_leaf_index] < cuda_leaf_num_data[right_leaf_index]) {
if (global_thread_index == 0) {
hist_t* parent_hist_ptr = cuda_hist_pool[left_leaf_index];
cuda_hist_pool[right_leaf_index] = parent_hist_ptr;
cuda_hist_pool[left_leaf_index] = cuda_hist + 2 * right_leaf_index * num_total_bin;
smaller_leaf_splits->hist_in_leaf = cuda_hist_pool[left_leaf_index];
larger_leaf_splits->hist_in_leaf = cuda_hist_pool[right_leaf_index];
} else if (global_thread_index == 1) {
smaller_leaf_splits->sum_of_gradients = best_split_info->left_sum_gradients;
} else if (global_thread_index == 2) {
smaller_leaf_splits->sum_of_hessians = best_split_info->left_sum_hessians;
} else if (global_thread_index == 3) {
smaller_leaf_splits->num_data_in_leaf = to_left_total_cnt;
} else if (global_thread_index == 4) {
smaller_leaf_splits->gain = best_split_info->left_gain;
} else if (global_thread_index == 5) {
smaller_leaf_splits->leaf_value = best_split_info->left_value;
} else if (global_thread_index == 6) {
smaller_leaf_splits->data_indices_in_leaf = cuda_data_indices;
} else if (global_thread_index == 7) {
larger_leaf_splits->leaf_index = right_leaf_index;
} else if (global_thread_index == 8) {
larger_leaf_splits->sum_of_gradients = best_split_info->right_sum_gradients;
} else if (global_thread_index == 9) {
larger_leaf_splits->sum_of_hessians = best_split_info->right_sum_hessians;
} else if (global_thread_index == 10) {
larger_leaf_splits->num_data_in_leaf = cuda_leaf_num_data[right_leaf_index];
} else if (global_thread_index == 11) {
larger_leaf_splits->gain = best_split_info->right_gain;
} else if (global_thread_index == 12) {
larger_leaf_splits->leaf_value = best_split_info->right_value;
} else if (global_thread_index == 13) {
larger_leaf_splits->data_indices_in_leaf = cuda_data_indices + cuda_leaf_num_data[left_leaf_index];
} else if (global_thread_index == 14) {
cuda_split_info_buffer[6] = left_leaf_index;
} else if (global_thread_index == 15) {
cuda_split_info_buffer[7] = right_leaf_index;
} else if (global_thread_index == 16) {
smaller_leaf_splits->leaf_index = left_leaf_index;
}
} else {
if (global_thread_index == 0) {
larger_leaf_splits->leaf_index = left_leaf_index;
} else if (global_thread_index == 1) {
larger_leaf_splits->sum_of_gradients = best_split_info->left_sum_gradients;
} else if (global_thread_index == 2) {
larger_leaf_splits->sum_of_hessians = best_split_info->left_sum_hessians;
} else if (global_thread_index == 3) {
larger_leaf_splits->num_data_in_leaf = to_left_total_cnt;
} else if (global_thread_index == 4) {
larger_leaf_splits->gain = best_split_info->left_gain;
} else if (global_thread_index == 5) {
larger_leaf_splits->leaf_value = best_split_info->left_value;
} else if (global_thread_index == 6) {
larger_leaf_splits->data_indices_in_leaf = cuda_data_indices;
} else if (global_thread_index == 7) {
smaller_leaf_splits->leaf_index = right_leaf_index;
} else if (global_thread_index == 8) {
smaller_leaf_splits->sum_of_gradients = best_split_info->right_sum_gradients;
} else if (global_thread_index == 9) {
smaller_leaf_splits->sum_of_hessians = best_split_info->right_sum_hessians;
} else if (global_thread_index == 10) {
smaller_leaf_splits->num_data_in_leaf = cuda_leaf_num_data[right_leaf_index];
} else if (global_thread_index == 11) {
smaller_leaf_splits->gain = best_split_info->right_gain;
} else if (global_thread_index == 12) {
smaller_leaf_splits->leaf_value = best_split_info->right_value;
} else if (global_thread_index == 13) {
smaller_leaf_splits->data_indices_in_leaf = cuda_data_indices + cuda_leaf_num_data[left_leaf_index];
} else if (global_thread_index == 14) {
cuda_hist_pool[right_leaf_index] = cuda_hist + 2 * right_leaf_index * num_total_bin;
smaller_leaf_splits->hist_in_leaf = cuda_hist_pool[right_leaf_index];
} else if (global_thread_index == 15) {
larger_leaf_splits->hist_in_leaf = cuda_hist_pool[left_leaf_index];
} else if (global_thread_index == 16) {
cuda_split_info_buffer[6] = right_leaf_index;
} else if (global_thread_index == 17) {
cuda_split_info_buffer[7] = left_leaf_index;
}
}
}
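// Scatters the data indices of the parent leaf into the output buffer: each thread
// handles one data point, uses the per-block inclusive prefix counts in
// block_to_left_offset to decide whether its index goes left or right, and writes it
// at the block's global left or right offset computed by the aggregation kernels.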
__global__ void SplitInnerKernel(const int left_leaf_index, const int right_leaf_index,
const data_size_t* cuda_leaf_data_start, const data_size_t* cuda_leaf_num_data,
const data_size_t* cuda_data_indices,
const data_size_t* block_to_left_offset_buffer, const data_size_t* block_to_right_offset_buffer,
const uint16_t* block_to_left_offset, data_size_t* out_data_indices_in_leaf) {
const data_size_t leaf_num_data_offset = cuda_leaf_data_start[left_leaf_index];
const data_size_t num_data_in_leaf = cuda_leaf_num_data[left_leaf_index] + cuda_leaf_num_data[right_leaf_index];
const unsigned int threadIdx_x = threadIdx.x;
const unsigned int blockDim_x = blockDim.x;
const unsigned int global_thread_index = blockIdx.x * blockDim_x + threadIdx_x;
const data_size_t* cuda_data_indices_in_leaf = cuda_data_indices + leaf_num_data_offset;
const uint16_t* block_to_left_offset_ptr = block_to_left_offset + blockIdx.x * blockDim_x;
const uint32_t to_right_block_offset = block_to_right_offset_buffer[blockIdx.x];
const uint32_t to_left_block_offset = block_to_left_offset_buffer[blockIdx.x];
data_size_t* left_out_data_indices_in_leaf = out_data_indices_in_leaf + to_left_block_offset;
data_size_t* right_out_data_indices_in_leaf = out_data_indices_in_leaf + to_right_block_offset;
if (static_cast<data_size_t>(global_thread_index) < num_data_in_leaf) {
const uint32_t thread_to_left_offset = (threadIdx_x == 0 ? 0 : block_to_left_offset_ptr[threadIdx_x - 1]);
const bool to_left = block_to_left_offset_ptr[threadIdx_x] > thread_to_left_offset;
if (to_left) {
left_out_data_indices_in_leaf[thread_to_left_offset] = cuda_data_indices_in_leaf[global_thread_index];
} else {
const uint32_t thread_to_right_offset = threadIdx.x - thread_to_left_offset;
right_out_data_indices_in_leaf[thread_to_right_offset] = cuda_data_indices_in_leaf[global_thread_index];
}
}
}
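// Copies the partitioned data indices from the temporary output buffer back into the
// leaf's range of cuda_data_indices_.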
__global__ void CopyDataIndicesKernel(
const data_size_t num_data_in_leaf,
const data_size_t* out_data_indices_in_leaf,
data_size_t* cuda_data_indices) {
const unsigned int threadIdx_x = threadIdx.x;
const unsigned int global_thread_index = blockIdx.x * blockDim.x + threadIdx_x;
if (global_thread_index < num_data_in_leaf) {
cuda_data_indices[global_thread_index] = out_data_indices_in_leaf[global_thread_index];
}
}
void CUDADataPartition::LaunchSplitInnerKernel(
const data_size_t num_data_in_leaf,
const CUDASplitInfo* best_split_info,
const int left_leaf_index,
const int right_leaf_index,
// for leaf splits information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
data_size_t* left_leaf_num_data_ref,
data_size_t* right_leaf_num_data_ref,
data_size_t* left_leaf_start_ref,
data_size_t* right_leaf_start_ref,
double* left_leaf_sum_of_hessians_ref,
double* right_leaf_sum_of_hessians_ref,
double* left_leaf_sum_of_gradients_ref,
double* right_leaf_sum_of_gradients_ref) {
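// round the number of data blocks up to the next power of two, so that the
// single-block aggregation kernel can launch with an aligned thread count for its
// shuffle-based prefix sum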
int num_blocks_final_ref = grid_dim_ - 1;
int num_blocks_final_aligned = 1;
while (num_blocks_final_ref > 0) {
num_blocks_final_aligned <<= 1;
num_blocks_final_ref >>= 1;
}
global_timer.Start("CUDADataPartition::AggregateBlockOffsetKernel");
if (grid_dim_ > AGGREGATE_BLOCK_SIZE_DATA_PARTITION) {
AggregateBlockOffsetKernel0<<<1, AGGREGATE_BLOCK_SIZE_DATA_PARTITION, 0, cuda_streams_[0]>>>(
left_leaf_index,
right_leaf_index,
cuda_block_data_to_left_offset_,
cuda_block_data_to_right_offset_, cuda_leaf_data_start_, cuda_leaf_data_end_,
cuda_leaf_num_data_, cuda_data_indices_,
grid_dim_);
} else {
AggregateBlockOffsetKernel1<<<1, num_blocks_final_aligned, 0, cuda_streams_[0]>>>(
left_leaf_index,
right_leaf_index,
cuda_block_data_to_left_offset_,
cuda_block_data_to_right_offset_, cuda_leaf_data_start_, cuda_leaf_data_end_,
cuda_leaf_num_data_, cuda_data_indices_,
grid_dim_);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDADataPartition::AggregateBlockOffsetKernel");
global_timer.Start("CUDADataPartition::SplitInnerKernel");
SplitInnerKernel<<<grid_dim_, block_dim_, 0, cuda_streams_[1]>>>(
left_leaf_index, right_leaf_index, cuda_leaf_data_start_, cuda_leaf_num_data_, cuda_data_indices_,
cuda_block_data_to_left_offset_, cuda_block_data_to_right_offset_, cuda_block_to_left_offset_,
cuda_out_data_indices_in_leaf_);
global_timer.Stop("CUDADataPartition::SplitInnerKernel");
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Start("CUDADataPartition::SplitTreeStructureKernel");
SplitTreeStructureKernel<<<4, 5, 0, cuda_streams_[0]>>>(left_leaf_index, right_leaf_index,
cuda_block_data_to_left_offset_,
cuda_block_data_to_right_offset_, cuda_leaf_data_start_, cuda_leaf_data_end_,
cuda_leaf_num_data_, cuda_out_data_indices_in_leaf_,
best_split_info,
smaller_leaf_splits,
larger_leaf_splits,
num_total_bin_,
cuda_hist_,
cuda_hist_pool_,
cuda_leaf_output_, cuda_split_info_buffer_);
global_timer.Stop("CUDADataPartition::SplitTreeStructureKernel");
std::vector<int> cpu_split_info_buffer(16);
const double* cpu_sum_hessians_info = reinterpret_cast<const double*>(cpu_split_info_buffer.data() + 8);
global_timer.Start("CUDADataPartition::CopyFromCUDADeviceToHostAsync");
CopyFromCUDADeviceToHostAsync<int>(cpu_split_info_buffer.data(), cuda_split_info_buffer_, 16, cuda_streams_[0], __FILE__, __LINE__);
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDADataPartition::CopyFromCUDADeviceToHostAsync");
const data_size_t left_leaf_num_data = cpu_split_info_buffer[1];
const data_size_t left_leaf_data_start = cpu_split_info_buffer[2];
const data_size_t right_leaf_num_data = cpu_split_info_buffer[4];
global_timer.Start("CUDADataPartition::CopyDataIndicesKernel");
CopyDataIndicesKernel<<<grid_dim_, block_dim_, 0, cuda_streams_[2]>>>(
left_leaf_num_data + right_leaf_num_data, cuda_out_data_indices_in_leaf_, cuda_data_indices_ + left_leaf_data_start);
global_timer.Stop("CUDADataPartition::CopyDataIndicesKernel");
const data_size_t right_leaf_data_start = cpu_split_info_buffer[5];
*left_leaf_num_data_ref = left_leaf_num_data;
*left_leaf_start_ref = left_leaf_data_start;
*right_leaf_num_data_ref = right_leaf_num_data;
*right_leaf_start_ref = right_leaf_data_start;
*left_leaf_sum_of_hessians_ref = cpu_sum_hessians_info[0];
*right_leaf_sum_of_hessians_ref = cpu_sum_hessians_info[1];
*left_leaf_sum_of_gradients_ref = cpu_sum_hessians_info[2];
*right_leaf_sum_of_gradients_ref = cpu_sum_hessians_info[3];
}
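// Writes the leaf output value of each data point into cuda_scores. With bagging,
// data_indices_in_leaf holds the subset of used data indices, so the global index is
// looked up first; without bagging the local thread index is the data index itself.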
template <bool USE_BAGGING>
__global__ void AddPredictionToScoreKernel(
const data_size_t* data_indices_in_leaf,
const double* leaf_value, double* cuda_scores,
const int* cuda_data_index_to_leaf_index, const data_size_t num_data) {
const unsigned int threadIdx_x = threadIdx.x;
const unsigned int blockIdx_x = blockIdx.x;
const unsigned int blockDim_x = blockDim.x;
const data_size_t local_data_index = static_cast<data_size_t>(blockIdx_x * blockDim_x + threadIdx_x);
if (local_data_index < num_data) {
if (USE_BAGGING) {
const data_size_t global_data_index = data_indices_in_leaf[local_data_index];
const int leaf_index = cuda_data_index_to_leaf_index[global_data_index];
const double leaf_prediction_value = leaf_value[leaf_index];
cuda_scores[local_data_index] = leaf_prediction_value;
} else {
const int leaf_index = cuda_data_index_to_leaf_index[local_data_index];
const double leaf_prediction_value = leaf_value[leaf_index];
cuda_scores[local_data_index] = leaf_prediction_value;
}
}
}
void CUDADataPartition::LaunchAddPredictionToScoreKernel(const double* leaf_value, double* cuda_scores) {
global_timer.Start("CUDADataPartition::AddPredictionToScoreKernel");
const data_size_t num_data_in_root = root_num_data();
const int num_blocks = (num_data_in_root + FILL_INDICES_BLOCK_SIZE_DATA_PARTITION - 1) / FILL_INDICES_BLOCK_SIZE_DATA_PARTITION;
if (use_bagging_) {
AddPredictionToScoreKernel<true><<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(
cuda_data_indices_, leaf_value, cuda_scores, cuda_data_index_to_leaf_index_, num_data_in_root);
} else {
AddPredictionToScoreKernel<false><<<num_blocks, FILL_INDICES_BLOCK_SIZE_DATA_PARTITION>>>(
cuda_data_indices_, leaf_value, cuda_scores, cuda_data_index_to_leaf_index_, num_data_in_root);
}
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDADataPartition::AddPredictionToScoreKernel");
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_CUDA_DATA_PARTITION_HPP_
#define LIGHTGBM_TREELEARNER_CUDA_CUDA_DATA_PARTITION_HPP_
#ifdef USE_CUDA_EXP
#include <LightGBM/bin.h>
#include <LightGBM/meta.h>
#include <LightGBM/tree.h>
#include <vector>
#include <LightGBM/cuda/cuda_column_data.hpp>
#include <LightGBM/cuda/cuda_split_info.hpp>
#include <LightGBM/cuda/cuda_tree.hpp>
#include "cuda_leaf_splits.hpp"
#define FILL_INDICES_BLOCK_SIZE_DATA_PARTITION (1024)
#define SPLIT_INDICES_BLOCK_SIZE_DATA_PARTITION (1024)
#define AGGREGATE_BLOCK_SIZE_DATA_PARTITION (1024)
namespace LightGBM {
class CUDADataPartition {
public:
CUDADataPartition(
const Dataset* train_data,
const int num_total_bin,
const int num_leaves,
const int num_threads,
hist_t* cuda_hist);
~CUDADataPartition();
void Init();
void BeforeTrain();
void Split(
// input best split info
const CUDASplitInfo* best_split_info,
const int left_leaf_index,
const int right_leaf_index,
const int leaf_best_split_feature,
const uint32_t leaf_best_split_threshold,
const uint32_t* categorical_bitset,
const int categorical_bitset_len,
const uint8_t leaf_best_split_default_left,
const data_size_t num_data_in_leaf,
const data_size_t leaf_data_start,
// for leaf information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
// gather information for CPU, used for launching kernels
data_size_t* left_leaf_num_data,
data_size_t* right_leaf_num_data,
data_size_t* left_leaf_start,
data_size_t* right_leaf_start,
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients);
void UpdateTrainScore(const Tree* tree, double* cuda_scores);
void SetUsedDataIndices(const data_size_t* used_indices, const data_size_t num_used_indices);
void SetBaggingSubset(const Dataset* subset);
void ResetTrainingData(const Dataset* train_data, const int num_total_bin, hist_t* cuda_hist);
void ResetConfig(const Config* config, hist_t* cuda_hist);
void ResetByLeafPred(const std::vector<int>& leaf_pred, int num_leaves);
data_size_t root_num_data() const {
if (use_bagging_) {
return num_used_indices_;
} else {
return num_data_;
}
}
const data_size_t* cuda_data_indices() const { return cuda_data_indices_; }
const data_size_t* cuda_leaf_num_data() const { return cuda_leaf_num_data_; }
const data_size_t* cuda_leaf_data_start() const { return cuda_leaf_data_start_; }
const int* cuda_data_index_to_leaf_index() const { return cuda_data_index_to_leaf_index_; }
bool use_bagging() const { return use_bagging_; }
private:
void CalcBlockDim(const data_size_t num_data_in_leaf);
void GenDataToLeftBitVector(
const data_size_t num_data_in_leaf,
const int split_feature_index,
const uint32_t split_threshold,
const uint32_t* categorical_bitset,
const int categorical_bitset_len,
const uint8_t split_default_left,
const data_size_t leaf_data_start,
const int left_leaf_index,
const int right_leaf_index);
void SplitInner(
// input best split info
const data_size_t num_data_in_leaf,
const CUDASplitInfo* best_split_info,
const int left_leaf_index,
const int right_leaf_index,
// for leaf splits information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
// gather information for CPU, used for launching kernels
data_size_t* left_leaf_num_data,
data_size_t* right_leaf_num_data,
data_size_t* left_leaf_start,
data_size_t* right_leaf_start,
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients);
// kernel launch functions
void LaunchFillDataIndicesBeforeTrain();
void LaunchSplitInnerKernel(
// input best split info
const data_size_t num_data_in_leaf,
const CUDASplitInfo* best_split_info,
const int left_leaf_index,
const int right_leaf_index,
// for leaf splits information update
CUDALeafSplitsStruct* smaller_leaf_splits,
CUDALeafSplitsStruct* larger_leaf_splits,
// gather information for CPU, used for launching kernels
data_size_t* left_leaf_num_data,
data_size_t* right_leaf_num_data,
data_size_t* left_leaf_start,
data_size_t* right_leaf_start,
double* left_leaf_sum_of_hessians,
double* right_leaf_sum_of_hessians,
double* left_leaf_sum_of_gradients,
double* right_leaf_sum_of_gradients);
void LaunchGenDataToLeftBitVectorKernel(
const data_size_t num_data_in_leaf,
const int split_feature_index,
const uint32_t split_threshold,
const uint8_t split_default_left,
const data_size_t leaf_data_start,
const int left_leaf_index,
const int right_leaf_index);
void LaunchGenDataToLeftBitVectorCategoricalKernel(
const data_size_t num_data_in_leaf,
const int split_feature_index,
const uint32_t* bitset,
const int bitset_len,
const uint8_t split_default_left,
const data_size_t leaf_data_start,
const int left_leaf_index,
const int right_leaf_index);
#define GenDataToLeftBitVectorKernel_PARMS \
const BIN_TYPE* column_data, \
const data_size_t num_data_in_leaf, \
const data_size_t* data_indices_in_leaf, \
const uint32_t th, \
const uint32_t t_zero_bin, \
const uint32_t max_bin, \
const uint32_t min_bin, \
const uint8_t split_default_to_left, \
const uint8_t split_missing_default_to_left
template <typename BIN_TYPE>
void LaunchGenDataToLeftBitVectorKernelInner(
GenDataToLeftBitVectorKernel_PARMS,
const bool missing_is_zero,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, typename BIN_TYPE>
void LaunchGenDataToLeftBitVectorKernelInner0(
GenDataToLeftBitVectorKernel_PARMS,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, typename BIN_TYPE>
void LaunchGenDataToLeftBitVectorKernelInner1(
GenDataToLeftBitVectorKernel_PARMS,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, typename BIN_TYPE>
void LaunchGenDataToLeftBitVectorKernelInner2(
GenDataToLeftBitVectorKernel_PARMS,
const bool mfb_is_na,
const bool max_bin_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, typename BIN_TYPE>
void LaunchGenDataToLeftBitVectorKernelInner3(
GenDataToLeftBitVectorKernel_PARMS,
const bool max_bin_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, bool MAX_TO_LEFT, typename BIN_TYPE>
void LaunchGenDataToLeftBitVectorKernelInner4(
GenDataToLeftBitVectorKernel_PARMS,
const bool is_single_feature_in_column);
#undef GenDataToLeftBitVectorKernel_PARMS
#define UpdateDataIndexToLeafIndexKernel_PARAMS \
const BIN_TYPE* column_data, \
const data_size_t num_data_in_leaf, \
const data_size_t* data_indices_in_leaf, \
const uint32_t th, \
const uint32_t t_zero_bin, \
const uint32_t max_bin_ref, \
const uint32_t min_bin_ref, \
const int left_leaf_index, \
const int right_leaf_index, \
const int default_leaf_index, \
const int missing_default_leaf_index
template <typename BIN_TYPE>
void LaunchUpdateDataIndexToLeafIndexKernel(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool missing_is_zero,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, typename BIN_TYPE>
void LaunchUpdateDataIndexToLeafIndexKernel_Inner0(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool missing_is_na,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, typename BIN_TYPE>
void LaunchUpdateDataIndexToLeafIndexKernel_Inner1(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool mfb_is_zero,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, typename BIN_TYPE>
void LaunchUpdateDataIndexToLeafIndexKernel_Inner2(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool mfb_is_na,
const bool max_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, typename BIN_TYPE>
void LaunchUpdateDataIndexToLeafIndexKernel_Inner3(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool max_to_left,
const bool is_single_feature_in_column);
template <bool MIN_IS_MAX, bool MISSING_IS_ZERO, bool MISSING_IS_NA, bool MFB_IS_ZERO, bool MFB_IS_NA, bool MAX_TO_LEFT, typename BIN_TYPE>
void LaunchUpdateDataIndexToLeafIndexKernel_Inner4(
UpdateDataIndexToLeafIndexKernel_PARAMS,
const bool is_single_feature_in_column);
#undef UpdateDataIndexToLeafIndexKernel_PARAMS
void LaunchAddPredictionToScoreKernel(const double* leaf_value, double* cuda_scores);
void LaunchFillDataIndexToLeafIndex();
// Host memory
// dataset information
/*! \brief number of training data */
data_size_t num_data_;
/*! \brief number of features in training data */
int num_features_;
/*! \brief number of total bins in training data */
int num_total_bin_;
/*! \brief bin data stored by column */
const CUDAColumnData* cuda_column_data_;
/*! \brief grid dimension when splitting one leaf */
int grid_dim_;
/*! \brief block dimension when splitting one leaf */
int block_dim_;
/*! \brief host buffer for the train scores to be added */
mutable std::vector<double> add_train_score_;
/*! \brief data indices used in this iteration */
const data_size_t* used_indices_;
/*! \brief marks whether a feature is a categorical feature */
std::vector<bool> is_categorical_feature_;
/*! \brief marks whether a feature is the only feature in its group */
std::vector<bool> is_single_feature_in_column_;
// config information
/*! \brief maximum number of leaves in a tree */
int num_leaves_;
/*! \brief number of threads */
int num_threads_;
// per iteration information
/*! \brief whether bagging is used in this iteration */
bool use_bagging_;
/*! \brief number of used data indices in this iteration */
data_size_t num_used_indices_;
// tree structure information
/*! \brief current number of leaves in tree */
int cur_num_leaves_;
// split algorithm related
/*! \brief maximum number of blocks whose per-block offsets are aggregated after the split bit vectors are generated */
int max_num_split_indices_blocks_;
// CUDA streams
/*! \brief CUDA streams used to overlap kernel execution and memory copies */
std::vector<cudaStream_t> cuda_streams_;
// CUDA memory, held by this object
// tree structure information
/*! \brief data indices by leaf */
data_size_t* cuda_data_indices_;
/*! \brief start position of each leaf in cuda_data_indices_ */
data_size_t* cuda_leaf_data_start_;
/*! \brief end position of each leaf in cuda_data_indices_ */
data_size_t* cuda_leaf_data_end_;
/*! \brief number of data in each leaf */
data_size_t* cuda_leaf_num_data_;
/*! \brief pointers to the histogram of each leaf */
hist_t** cuda_hist_pool_;
/*! \brief records the value of each leaf */
double* cuda_leaf_output_;
// split data algorithm related
uint16_t* cuda_block_to_left_offset_;
/*! \brief maps data index to leaf index, for adding scores to training data set */
int* cuda_data_index_to_leaf_index_;
/*! \brief prefix sum of number of data going to left in all blocks */
data_size_t* cuda_block_data_to_left_offset_;
/*! \brief prefix sum of number of data going to right in all blocks */
data_size_t* cuda_block_data_to_right_offset_;
/*! \brief buffer for splitting data indices, will be copied back to cuda_data_indices_ after split */
data_size_t* cuda_out_data_indices_in_leaf_;
// split tree structure algorithm related
/*! \brief buffer storing split information, to be copied back to the CPU */
int* cuda_split_info_buffer_;
// dataset information
/*! \brief number of data in training set, for initialization of cuda_leaf_num_data_ and cuda_leaf_data_end_ */
data_size_t* cuda_num_data_;
// for train score update
/*! \brief CUDA buffer for the train scores to be added */
double* cuda_add_train_score_;
// CUDA memory, held by other object
// dataset information
/*! \brief beginning of histograms, for initialization of cuda_hist_pool_ */
hist_t* cuda_hist_;
};
} // namespace LightGBM
#endif // USE_CUDA_EXP
#endif // LIGHTGBM_TREELEARNER_CUDA_CUDA_DATA_PARTITION_HPP_
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include "cuda_histogram_constructor.hpp"
#include <algorithm>
namespace LightGBM {
CUDAHistogramConstructor::CUDAHistogramConstructor(
const Dataset* train_data,
const int num_leaves,
const int num_threads,
const std::vector<uint32_t>& feature_hist_offsets,
const int min_data_in_leaf,
const double min_sum_hessian_in_leaf,
const int gpu_device_id,
const bool gpu_use_dp):
num_data_(train_data->num_data()),
num_features_(train_data->num_features()),
num_leaves_(num_leaves),
num_threads_(num_threads),
min_data_in_leaf_(min_data_in_leaf),
min_sum_hessian_in_leaf_(min_sum_hessian_in_leaf),
gpu_device_id_(gpu_device_id),
gpu_use_dp_(gpu_use_dp) {
InitFeatureMetaInfo(train_data, feature_hist_offsets);
cuda_row_data_.reset(nullptr);
cuda_feature_num_bins_ = nullptr;
cuda_feature_hist_offsets_ = nullptr;
cuda_feature_most_freq_bins_ = nullptr;
cuda_hist_ = nullptr;
cuda_need_fix_histogram_features_ = nullptr;
cuda_need_fix_histogram_features_num_bin_aligned_ = nullptr;
}
CUDAHistogramConstructor::~CUDAHistogramConstructor() {
DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__);
}
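// Collects per-feature bin metadata and records the features whose most frequent bin
// is non-zero; FixHistogramKernel later reconstructs that bin's histogram entry from
// the leaf totals.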
void CUDAHistogramConstructor::InitFeatureMetaInfo(const Dataset* train_data, const std::vector<uint32_t>& feature_hist_offsets) {
need_fix_histogram_features_.clear();
need_fix_histogram_features_num_bin_aligend_.clear();
feature_num_bins_.clear();
feature_most_freq_bins_.clear();
for (int feature_index = 0; feature_index < train_data->num_features(); ++feature_index) {
const BinMapper* bin_mapper = train_data->FeatureBinMapper(feature_index);
const uint32_t most_freq_bin = bin_mapper->GetMostFreqBin();
if (most_freq_bin != 0) {
need_fix_histogram_features_.emplace_back(feature_index);
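// round the bin count up to a power of two so that FixHistogramKernel can run its
// shuffle-based reduction over an aligned width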
uint32_t num_bin_ref = static_cast<uint32_t>(bin_mapper->num_bin()) - 1;
uint32_t num_bin_aligned = 1;
while (num_bin_ref > 0) {
num_bin_aligned <<= 1;
num_bin_ref >>= 1;
}
need_fix_histogram_features_num_bin_aligend_.emplace_back(num_bin_aligned);
}
feature_num_bins_.emplace_back(static_cast<uint32_t>(bin_mapper->num_bin()));
feature_most_freq_bins_.emplace_back(most_freq_bin);
}
feature_hist_offsets_.clear();
for (size_t i = 0; i < feature_hist_offsets.size(); ++i) {
feature_hist_offsets_.emplace_back(feature_hist_offsets[i]);
}
if (feature_hist_offsets.empty()) {
num_total_bin_ = 0;
} else {
num_total_bin_ = static_cast<int>(feature_hist_offsets.back());
}
}
void CUDAHistogramConstructor::BeforeTrain(const score_t* gradients, const score_t* hessians) {
cuda_gradients_ = gradients;
cuda_hessians_ = hessians;
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
}
void CUDAHistogramConstructor::Init(const Dataset* train_data, TrainingShareStates* share_state) {
AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
cuda_row_data_.reset(new CUDARowData(train_data, share_state, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_state);
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_));
InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
if (cuda_row_data_->NumLargeBinPartition() > 0) {
int grid_dim_x = 0, grid_dim_y = 0, block_dim_x = 0, block_dim_y = 0;
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_);
const size_t buffer_size = static_cast<size_t>(grid_dim_y) * static_cast<size_t>(num_total_bin_) * 2;
AllocateCUDAMemory<float>(&cuda_hist_buffer_, buffer_size, __FILE__, __LINE__);
}
}
void CUDAHistogramConstructor::ConstructHistogramForLeaf(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits,
const data_size_t num_data_in_smaller_leaf,
const data_size_t num_data_in_larger_leaf,
const double sum_hessians_in_smaller_leaf,
const double sum_hessians_in_larger_leaf) {
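// skip histogram construction when neither leaf has enough data or Hessian sum to be split further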
if ((num_data_in_smaller_leaf <= min_data_in_leaf_ || sum_hessians_in_smaller_leaf <= min_sum_hessian_in_leaf_) &&
(num_data_in_larger_leaf <= min_data_in_leaf_ || sum_hessians_in_larger_leaf <= min_sum_hessian_in_leaf_)) {
return;
}
LaunchConstructHistogramKernel(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Start("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
LaunchSubtractHistogramKernel(cuda_smaller_leaf_splits, cuda_larger_leaf_splits);
global_timer.Stop("CUDAHistogramConstructor::ConstructHistogramForLeaf::LaunchSubtractHistogramKernel");
}
void CUDAHistogramConstructor::CalcConstructHistogramKernelDim(
int* grid_dim_x,
int* grid_dim_y,
int* block_dim_x,
int* block_dim_y,
const data_size_t num_data_in_smaller_leaf) {
*block_dim_x = cuda_row_data_->max_num_column_per_partition();
*block_dim_y = NUM_THRADS_PER_BLOCK / cuda_row_data_->max_num_column_per_partition();
*grid_dim_x = cuda_row_data_->num_feature_partitions();
*grid_dim_y = std::max(min_grid_dim_y_,
((num_data_in_smaller_leaf + NUM_DATA_PER_THREAD - 1) / NUM_DATA_PER_THREAD + (*block_dim_y) - 1) / (*block_dim_y));
}
void CUDAHistogramConstructor::ResetTrainingData(const Dataset* train_data, TrainingShareStates* share_states) {
num_data_ = train_data->num_data();
num_features_ = train_data->num_features();
InitFeatureMetaInfo(train_data, share_states->feature_hist_offsets());
if (feature_num_bins_.size() > 0) {
DeallocateCUDAMemory<uint32_t>(&cuda_feature_num_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_hist_offsets_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_feature_most_freq_bins_, __FILE__, __LINE__);
DeallocateCUDAMemory<int>(&cuda_need_fix_histogram_features_, __FILE__, __LINE__);
DeallocateCUDAMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, __FILE__, __LINE__);
DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
}
AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_num_bins_,
feature_num_bins_.data(), feature_num_bins_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_hist_offsets_,
feature_hist_offsets_.data(), feature_hist_offsets_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_feature_most_freq_bins_,
feature_most_freq_bins_.data(), feature_most_freq_bins_.size(), __FILE__, __LINE__);
cuda_row_data_.reset(new CUDARowData(train_data, share_states, gpu_device_id_, gpu_use_dp_));
cuda_row_data_->Init(train_data, share_states);
InitCUDAMemoryFromHostMemory<int>(&cuda_need_fix_histogram_features_, need_fix_histogram_features_.data(), need_fix_histogram_features_.size(), __FILE__, __LINE__);
InitCUDAMemoryFromHostMemory<uint32_t>(&cuda_need_fix_histogram_features_num_bin_aligned_, need_fix_histogram_features_num_bin_aligend_.data(),
need_fix_histogram_features_num_bin_aligend_.size(), __FILE__, __LINE__);
}
void CUDAHistogramConstructor::ResetConfig(const Config* config) {
num_threads_ = OMP_NUM_THREADS();
num_leaves_ = config->num_leaves;
min_data_in_leaf_ = config->min_data_in_leaf;
min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf;
DeallocateCUDAMemory<hist_t>(&cuda_hist_, __FILE__, __LINE__);
AllocateCUDAMemory<hist_t>(&cuda_hist_, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
SetCUDAMemory<hist_t>(cuda_hist_, 0, num_total_bin_ * 2 * num_leaves_, __FILE__, __LINE__);
}
} // namespace LightGBM
#endif // USE_CUDA_EXP
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA_EXP
#include "cuda_histogram_constructor.hpp"
#include <LightGBM/cuda/cuda_algorithms.hpp>
#include <algorithm>
namespace LightGBM {
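// Dense row-wise histogram construction. blockIdx.x selects a feature partition (a
// contiguous range of columns) and blockIdx.y/threadIdx.y a slice of the data in the
// smaller leaf; each threadIdx.x handles one column. Gradient/Hessian pairs are
// accumulated into a shared memory histogram with block-level atomics and then
// flushed into the leaf's global histogram with atomicAdd.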
template <typename BIN_TYPE, typename HIST_TYPE, size_t SHARED_HIST_SIZE>
__global__ void CUDAConstructHistogramDenseKernel(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const score_t* cuda_gradients,
const score_t* cuda_hessians,
const BIN_TYPE* data,
const uint32_t* column_hist_offsets,
const uint32_t* column_hist_offsets_full,
const int* feature_partition_column_index_offsets,
const data_size_t num_data) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
__shared__ HIST_TYPE shared_hist[SHARED_HIST_SIZE];
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
const BIN_TYPE* data_ptr = data + partition_column_start * num_data;
const int num_columns_in_partition = partition_column_end - partition_column_start;
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist[i] = 0.0f;
}
__syncthreads();
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
HIST_TYPE* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1);
for (data_size_t inner_data_index = static_cast<data_size_t>(threadIdx.y); inner_data_index < block_num_data; inner_data_index += blockDim.y) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]);
const uint32_t pos = bin << 1;
HIST_TYPE* pos_ptr = shared_hist_ptr + pos;
atomicAdd_block(pos_ptr, grad);
atomicAdd_block(pos_ptr + 1, hess);
}
}
__syncthreads();
hist_t* feature_histogram_ptr = smaller_leaf_splits->hist_in_leaf + (partition_hist_start << 1);
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
atomicAdd_system(feature_histogram_ptr + i, shared_hist[i]);
}
}
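// Sparse variant: rows are stored in a CSR-like layout per partition, with row_ptr
// giving the bin entries of each row and partition_ptr the start of each partition's
// data, so threadIdx.x indexes the non-trivial bins of the current row.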
template <typename BIN_TYPE, typename DATA_PTR_TYPE, typename HIST_TYPE, size_t SHARED_HIST_SIZE>
__global__ void CUDAConstructHistogramSparseKernel(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const score_t* cuda_gradients,
const score_t* cuda_hessians,
const BIN_TYPE* data,
const DATA_PTR_TYPE* row_ptr,
const DATA_PTR_TYPE* partition_ptr,
const uint32_t* column_hist_offsets_full,
const data_size_t num_data) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
__shared__ HIST_TYPE shared_hist[SHARED_HIST_SIZE];
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1);
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist[i] = 0.0f;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const DATA_PTR_TYPE row_start = block_row_ptr[data_index];
const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1];
const DATA_PTR_TYPE row_size = row_end - row_start;
if (threadIdx.x < row_size) {
const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
const uint32_t pos = bin << 1;
HIST_TYPE* pos_ptr = shared_hist + pos;
atomicAdd_block(pos_ptr, grad);
atomicAdd_block(pos_ptr + 1, hess);
}
inner_data_index += blockDim.y;
}
__syncthreads();
hist_t* feature_histogram_ptr = smaller_leaf_splits->hist_in_leaf + (partition_hist_start << 1);
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
atomicAdd_system(feature_histogram_ptr + i, shared_hist[i]);
}
}
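// Global memory variants: used when a feature partition's histogram is too large for
// shared memory. Each blockIdx.y accumulates into its own slice of global_hist_buffer,
// and the result is flushed into the leaf histogram in the same way as above.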
template <typename BIN_TYPE>
__global__ void CUDAConstructHistogramDenseKernel_GlobalMemory(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const score_t* cuda_gradients,
const score_t* cuda_hessians,
const BIN_TYPE* data,
const uint32_t* column_hist_offsets,
const uint32_t* column_hist_offsets_full,
const int* feature_partition_column_index_offsets,
const data_size_t num_data,
float* global_hist_buffer) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const int partition_column_start = feature_partition_column_index_offsets[blockIdx.x];
const int partition_column_end = feature_partition_column_index_offsets[blockIdx.x + 1];
const BIN_TYPE* data_ptr = data + partition_column_start * num_data;
const int num_columns_in_partition = partition_column_end - partition_column_start;
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
const int num_total_bin = column_hist_offsets_full[gridDim.x];
float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist[i] = 0.0f;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
const int column_index = static_cast<int>(threadIdx.x) + partition_column_start;
if (threadIdx.x < static_cast<unsigned int>(num_columns_in_partition)) {
float* shared_hist_ptr = shared_hist + (column_hist_offsets[column_index] << 1);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[data_index * num_columns_in_partition + threadIdx.x]);
const uint32_t pos = bin << 1;
float* pos_ptr = shared_hist_ptr + pos;
atomicAdd_block(pos_ptr, grad);
atomicAdd_block(pos_ptr + 1, hess);
inner_data_index += blockDim.y;
}
}
__syncthreads();
hist_t* feature_histogram_ptr = smaller_leaf_splits->hist_in_leaf + (partition_hist_start << 1);
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
atomicAdd_system(feature_histogram_ptr + i, shared_hist[i]);
}
}
template <typename BIN_TYPE, typename DATA_PTR_TYPE>
__global__ void CUDAConstructHistogramSparseKernel_GlobalMemory(
const CUDALeafSplitsStruct* smaller_leaf_splits,
const score_t* cuda_gradients,
const score_t* cuda_hessians,
const BIN_TYPE* data,
const DATA_PTR_TYPE* row_ptr,
const DATA_PTR_TYPE* partition_ptr,
const uint32_t* column_hist_offsets_full,
const data_size_t num_data,
float* global_hist_buffer) {
const int dim_y = static_cast<int>(gridDim.y * blockDim.y);
const data_size_t num_data_in_smaller_leaf = smaller_leaf_splits->num_data_in_leaf;
const data_size_t num_data_per_thread = (num_data_in_smaller_leaf + dim_y - 1) / dim_y;
const data_size_t* data_indices_ref = smaller_leaf_splits->data_indices_in_leaf;
const unsigned int num_threads_per_block = blockDim.x * blockDim.y;
const DATA_PTR_TYPE* block_row_ptr = row_ptr + blockIdx.x * (num_data + 1);
const BIN_TYPE* data_ptr = data + partition_ptr[blockIdx.x];
const uint32_t partition_hist_start = column_hist_offsets_full[blockIdx.x];
const uint32_t partition_hist_end = column_hist_offsets_full[blockIdx.x + 1];
const uint32_t num_items_in_partition = (partition_hist_end - partition_hist_start) << 1;
const unsigned int thread_idx = threadIdx.x + threadIdx.y * blockDim.x;
const int num_total_bin = column_hist_offsets_full[gridDim.x];
float* shared_hist = global_hist_buffer + (blockIdx.y * num_total_bin + partition_hist_start) * 2;
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
shared_hist[i] = 0.0f;
}
__syncthreads();
const unsigned int threadIdx_y = threadIdx.y;
const unsigned int blockIdx_y = blockIdx.y;
const data_size_t block_start = (blockIdx_y * blockDim.y) * num_data_per_thread;
const data_size_t* data_indices_ref_this_block = data_indices_ref + block_start;
data_size_t block_num_data = max(0, min(num_data_in_smaller_leaf - block_start, num_data_per_thread * static_cast<data_size_t>(blockDim.y)));
const data_size_t num_iteration_total = (block_num_data + blockDim.y - 1) / blockDim.y;
const data_size_t remainder = block_num_data % blockDim.y;
const data_size_t num_iteration_this = remainder == 0 ? num_iteration_total : num_iteration_total - static_cast<data_size_t>(threadIdx_y >= remainder);
data_size_t inner_data_index = static_cast<data_size_t>(threadIdx_y);
for (data_size_t i = 0; i < num_iteration_this; ++i) {
const data_size_t data_index = data_indices_ref_this_block[inner_data_index];
const DATA_PTR_TYPE row_start = block_row_ptr[data_index];
const DATA_PTR_TYPE row_end = block_row_ptr[data_index + 1];
const DATA_PTR_TYPE row_size = row_end - row_start;
if (threadIdx.x < row_size) {
const score_t grad = cuda_gradients[data_index];
const score_t hess = cuda_hessians[data_index];
const uint32_t bin = static_cast<uint32_t>(data_ptr[row_start + threadIdx.x]);
const uint32_t pos = bin << 1;
float* pos_ptr = shared_hist + pos;
atomicAdd_block(pos_ptr, grad);
atomicAdd_block(pos_ptr + 1, hess);
}
inner_data_index += blockDim.y;
}
__syncthreads();
hist_t* feature_histogram_ptr = smaller_leaf_splits->hist_in_leaf + (partition_hist_start << 1);
for (unsigned int i = thread_idx; i < num_items_in_partition; i += num_threads_per_block) {
atomicAdd_system(feature_histogram_ptr + i, shared_hist[i]);
}
}
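// Host-side dispatch: selects the histogram accumulation type (float or double), the
// bin type width, the row pointer width, and the shared vs. global memory kernel
// based on the layout reported by cuda_row_data_.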
void CUDAHistogramConstructor::LaunchConstructHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
if (cuda_row_data_->shared_hist_size() == DP_SHARED_HIST_SIZE && gpu_use_dp_) {
LaunchConstructHistogramKernelInner<double, DP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else if (cuda_row_data_->shared_hist_size() == SP_SHARED_HIST_SIZE && !gpu_use_dp_) {
LaunchConstructHistogramKernelInner<float, SP_SHARED_HIST_SIZE>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else {
Log::Fatal("Unknown shared histogram size %d", cuda_row_data_->shared_hist_size());
}
}
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
if (cuda_row_data_->bit_type() == 8) {
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint8_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else if (cuda_row_data_->bit_type() == 16) {
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else if (cuda_row_data_->bit_type() == 32) {
LaunchConstructHistogramKernelInner0<HIST_TYPE, SHARED_HIST_SIZE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else {
Log::Fatal("Unknown bit_type = %d", cuda_row_data_->bit_type());
}
}
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner0(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
if (cuda_row_data_->row_ptr_bit_type() == 16) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint16_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else if (cuda_row_data_->row_ptr_bit_type() == 32) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint32_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else if (cuda_row_data_->row_ptr_bit_type() == 64) {
LaunchConstructHistogramKernelInner1<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, uint64_t>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else {
Log::Fatal("Unknown row_ptr_bit_type = %d", cuda_row_data_->row_ptr_bit_type());
}
}
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner1(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
if (cuda_row_data_->NumLargeBinPartition() == 0) {
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, false>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
} else {
LaunchConstructHistogramKernelInner2<HIST_TYPE, SHARED_HIST_SIZE, BIN_TYPE, PTR_TYPE, true>(cuda_smaller_leaf_splits, num_data_in_smaller_leaf);
}
}
template <typename HIST_TYPE, size_t SHARED_HIST_SIZE, typename BIN_TYPE, typename PTR_TYPE, bool USE_GLOBAL_MEM_BUFFER>
void CUDAHistogramConstructor::LaunchConstructHistogramKernelInner2(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const data_size_t num_data_in_smaller_leaf) {
int grid_dim_x = 0;
int grid_dim_y = 0;
int block_dim_x = 0;
int block_dim_y = 0;
CalcConstructHistogramKernelDim(&grid_dim_x, &grid_dim_y, &block_dim_x, &block_dim_y, num_data_in_smaller_leaf);
dim3 grid_dim(grid_dim_x, grid_dim_y);
dim3 block_dim(block_dim_x, block_dim_y);
if (!USE_GLOBAL_MEM_BUFFER) {
if (cuda_row_data_->is_sparse()) {
CUDAConstructHistogramSparseKernel<BIN_TYPE, PTR_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_);
} else {
CUDAConstructHistogramDenseKernel<BIN_TYPE, HIST_TYPE, SHARED_HIST_SIZE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_);
}
} else {
if (cuda_row_data_->is_sparse()) {
CUDAConstructHistogramSparseKernel_GlobalMemory<BIN_TYPE, PTR_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->GetRowPtr<PTR_TYPE>(),
cuda_row_data_->GetPartitionPtr<PTR_TYPE>(),
cuda_row_data_->cuda_partition_hist_offsets(),
num_data_,
cuda_hist_buffer_);
} else {
CUDAConstructHistogramDenseKernel_GlobalMemory<BIN_TYPE><<<grid_dim, block_dim, 0, cuda_stream_>>>(
cuda_smaller_leaf_splits,
cuda_gradients_, cuda_hessians_,
cuda_row_data_->GetBin<BIN_TYPE>(),
cuda_row_data_->cuda_column_hist_offsets(),
cuda_row_data_->cuda_partition_hist_offsets(),
cuda_row_data_->cuda_feature_partition_column_index_offsets(),
num_data_,
cuda_hist_buffer_);
}
}
}
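// Computes the histogram of the larger leaf by subtraction: its histogram slot still
// holds the parent's histogram, so subtracting the smaller leaf's histogram in place
// yields the larger leaf's histogram.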
__global__ void SubtractHistogramKernel(
const int num_total_bin,
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits) {
const unsigned int global_thread_index = threadIdx.x + blockIdx.x * blockDim.x;
const int cuda_larger_leaf_index = cuda_larger_leaf_splits->leaf_index;
if (cuda_larger_leaf_index >= 0) {
const hist_t* smaller_leaf_hist = cuda_smaller_leaf_splits->hist_in_leaf;
hist_t* larger_leaf_hist = cuda_larger_leaf_splits->hist_in_leaf;
if (global_thread_index < 2 * num_total_bin) {
larger_leaf_hist[global_thread_index] -= smaller_leaf_hist[global_thread_index];
}
}
}
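// Restores the most frequent bin of each affected feature: a shuffle reduction sums
// the gradients/Hessians of all other bins, and the leaf totals minus that sum are
// written into the most frequent bin's slot.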
__global__ void FixHistogramKernel(
const uint32_t* cuda_feature_num_bins,
const uint32_t* cuda_feature_hist_offsets,
const uint32_t* cuda_feature_most_freq_bins,
const int* cuda_need_fix_histogram_features,
const uint32_t* cuda_need_fix_histogram_features_num_bin_aligned,
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits) {
__shared__ hist_t shared_mem_buffer[32];
const unsigned int blockIdx_x = blockIdx.x;
const int feature_index = cuda_need_fix_histogram_features[blockIdx_x];
const uint32_t num_bin_aligned = cuda_need_fix_histogram_features_num_bin_aligned[blockIdx_x];
const uint32_t feature_hist_offset = cuda_feature_hist_offsets[feature_index];
const uint32_t most_freq_bin = cuda_feature_most_freq_bins[feature_index];
const double leaf_sum_gradients = cuda_smaller_leaf_splits->sum_of_gradients;
const double leaf_sum_hessians = cuda_smaller_leaf_splits->sum_of_hessians;
hist_t* feature_hist = cuda_smaller_leaf_splits->hist_in_leaf + feature_hist_offset * 2;
const unsigned int threadIdx_x = threadIdx.x;
const uint32_t num_bin = cuda_feature_num_bins[feature_index];
const uint32_t hist_pos = threadIdx_x << 1;
const hist_t bin_gradient = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[hist_pos] : 0.0f;
const hist_t bin_hessian = (threadIdx_x < num_bin && threadIdx_x != most_freq_bin) ? feature_hist[hist_pos + 1] : 0.0f;
const hist_t sum_gradient = ShuffleReduceSum<hist_t>(bin_gradient, shared_mem_buffer, num_bin_aligned);
const hist_t sum_hessian = ShuffleReduceSum<hist_t>(bin_hessian, shared_mem_buffer, num_bin_aligned);
if (threadIdx_x == 0) {
feature_hist[most_freq_bin << 1] = leaf_sum_gradients - sum_gradient;
feature_hist[(most_freq_bin << 1) + 1] = leaf_sum_hessians - sum_hessian;
}
}
void CUDAHistogramConstructor::LaunchSubtractHistogramKernel(
const CUDALeafSplitsStruct* cuda_smaller_leaf_splits,
const CUDALeafSplitsStruct* cuda_larger_leaf_splits) {
const int num_subtract_threads = 2 * num_total_bin_;
const int num_subtract_blocks = (num_subtract_threads + SUBTRACT_BLOCK_SIZE - 1) / SUBTRACT_BLOCK_SIZE;
global_timer.Start("CUDAHistogramConstructor::FixHistogramKernel");
if (need_fix_histogram_features_.size() > 0) {
FixHistogramKernel<<<need_fix_histogram_features_.size(), FIX_HISTOGRAM_BLOCK_SIZE, 0, cuda_stream_>>>(
cuda_feature_num_bins_,
cuda_feature_hist_offsets_,
cuda_feature_most_freq_bins_,
cuda_need_fix_histogram_features_,
cuda_need_fix_histogram_features_num_bin_aligned_,
cuda_smaller_leaf_splits);
}
global_timer.Stop("CUDAHistogramConstructor::FixHistogramKernel");
global_timer.Start("CUDAHistogramConstructor::SubtractHistogramKernel");
SubtractHistogramKernel<<<num_subtract_blocks, SUBTRACT_BLOCK_SIZE, 0, cuda_stream_>>>(
num_total_bin_,
cuda_smaller_leaf_splits,
cuda_larger_leaf_splits);
global_timer.Stop("CUDAHistogramConstructor::SubtractHistogramKernel");
}
} // namespace LightGBM
#endif // USE_CUDA_EXP