Commit 0bb4a825 authored by Huan Zhang's avatar Huan Zhang Committed by Guolin Ke
Browse files

Initial GPU acceleration support for LightGBM (#368)

* add dummy gpu solver code

* initial GPU code

* fix crash bug

* first working version

* use asynchronous copy

* use a better kernel for root

* parallel read histogram

* sparse features now works, but no acceleration, compute on CPU

* compute sparse feature on CPU simultaneously

* fix big bug; add gpu selection; add kernel selection

* better debugging

* clean up

* add feature scatter

* Add sparse_threshold control

* fix a bug in feature scatter

* clean up debug

* temporarily add OpenCL kernels for k=64,256

* fix up CMakeList and definition USE_GPU

* add OpenCL kernels as string literals

* Add boost.compute as a submodule

* add boost dependency into CMakeList

* fix opencl pragma

* use pinned memory for histogram

* use pinned buffer for gradients and hessians

* better debugging message

* add double precision support on GPU

* fix boost version in CMakeList

* Add a README

* reconstruct GPU initialization code for ResetTrainingData

* move data to GPU in parallel

* fix a bug during feature copy

* update gpu kernels

* update gpu code

* initial port to LightGBM v2

* speedup GPU data loading process

* Add 4-bit bin support to GPU

* re-add sparse_threshold parameter

* remove kMaxNumWorkgroups and allows an unlimited number of features

* add feature mask support for skipping unused features

* enable kernel cache

* use GPU kernels without feature masks when all features are used

* README.

* README.

* update README

* fix typos (#349)

* change compile to gcc on Apple as default

* clean vscode related file

* refine api of constructing from sampling data.

* fix bug in the last commit.

* more efficient algorithm to sample k from n.

* fix bug in filter bin

* change to boost from average output.

* fix tests.

* only stop training when all classes are finished in multi-class.

* limit the max tree output. change hessian in multi-class objective.

* robust tree model loading.

* fix test.

* convert the probabilities to raw score in boost_from_average of classification.

* fix the average label for binary classification.

* Add boost_from_average to docs (#354)

* don't use "ConvertToRawScore" for self-defined objective function.

* boost_from_average seems doesn't work well in binary classification. remove it.

* For a better jump link (#355)

* Update Python-API.md

* for a better jump in page

A space is needed between `#` and the headers content according to Github's markdown format [guideline](https://guides.github.com/features/mastering-markdown/)

After adding the spaces, we can jump to the exact position in page by click the link.

* fixed something mentioned by @wxchan

* Update Python-API.md

* add FitByExistingTree.

* adapt GPU tree learner for FitByExistingTree

* avoid NaN output.

* update boost.compute

* fix typos (#361)

* fix broken links (#359)

* update README

* disable GPU acceleration by default

* fix image url

* cleanup debug macro

* remove old README

* do not save sparse_threshold_ in FeatureGroup

* add details for new GPU settings

* ignore submodule when doing pep8 check

* allocate workspace for at least one thread during builing Feature4

* move sparse_threshold to class Dataset

* remove duplicated code in GPUTreeLearner::Split

* Remove duplicated code in FindBestThresholds and BeforeFindBestSplit

* do not rebuild ordered gradients and hessians for sparse features

* support feature groups in GPUTreeLearner

* Initial parallel learners with GPU support

* add option device, cleanup code

* clean up FindBestThresholds; add some omp parallel

* constant hessian optimization for GPU

* Fix GPUTreeLearner crash when there is zero feature

* use np.testing.assert_almost_equal() to compare lists of floats in tests

* travis for GPU
parent db3d1f89
#ifndef LIGHTGBM_TREELEARNER_GPU_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_GPU_TREE_LEARNER_H_
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/dataset.h>
#include <LightGBM/tree.h>
#include <LightGBM/feature_group.h>
#include "feature_histogram.hpp"
#include "serial_tree_learner.h"
#include "data_partition.hpp"
#include "split_info.hpp"
#include "leaf_splits.hpp"
#include <cstdio>
#include <vector>
#include <random>
#include <cmath>
#include <memory>
#ifdef USE_GPU
#define BOOST_COMPUTE_THREAD_SAFE
#define BOOST_COMPUTE_HAVE_THREAD_LOCAL
// Use Boost.Compute on-disk kernel cache
#define BOOST_COMPUTE_USE_OFFLINE_CACHE
#include <boost/compute/core.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/align/aligned_allocator.hpp>
namespace LightGBM {
/*!
* \brief GPU-based parallel learning algorithm.
* Accelerates histogram construction on an OpenCL device (via Boost.Compute),
* while inheriting the rest of the tree-growing logic from SerialTreeLearner.
* Dense feature-groups are packed into 4-byte tuples and processed on the GPU;
* sparse feature-groups are processed on the CPU concurrently.
*/
class GPUTreeLearner: public SerialTreeLearner {
public:
explicit GPUTreeLearner(const TreeConfig* tree_config);
~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian) override;
void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(used_indices, num_data);
// determine if we are using bagging before we construct the data partition
// thus we can start data movement to GPU earlier
if (used_indices != nullptr) {
if (num_data != num_data_) {
use_bagging_ = true;
return;
}
}
use_bagging_ = false;
}
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestThresholds() override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private:
/*! \brief 4-byte feature tuple used by GPU kernels.
* Mirrors the layout of an OpenCL uchar4; the union allows access either by
* index (s[i]) or by component name (s0..s3). */
struct Feature4 {
union {
unsigned char s[4];
struct {
unsigned char s0;
unsigned char s1;
unsigned char s2;
unsigned char s3;
};
};
};
/*! \brief Single precision histogram entry for GPU.
* Must match the per-bin output layout written by the OpenCL kernels. */
struct GPUHistogramBinEntry {
score_t sum_gradients;
score_t sum_hessians;
uint32_t cnt;
};
/*!
* \brief Find the best number of workgroups processing one feature for maximizing efficiency
* \param leaf_num_data The number of data examples on the current leaf being processed
* \return Log2 of the best number of workgroups per feature, in range 0...kMaxLogWorkgroupsPerFeature
*/
int GetNumWorkgroupsPerFeature(data_size_t leaf_num_data);
/*!
* \brief Initialize GPU device, context and command queues
* Also compiles the OpenCL kernel
* \param platform_id OpenCL platform ID
* \param device_id OpenCL device ID
*/
void InitGPU(int platform_id, int device_id);
/*!
* \brief Allocate memory for GPU computation
*/
void AllocateGPUMemory();
/*!
* \brief Compile OpenCL GPU source code to kernel binaries
*/
void BuildGPUKernels();
/*!
* \brief Setup GPU kernel arguments, preparing for launching
*/
void SetupKernelArguments();
/*!
* \brief Compute GPU feature histogram for the current leaf.
* Indices, gradients and hessians have been copied to the device.
* \param leaf_num_data Number of data on current leaf
* \param use_all_features Set to true to not use feature masks, with a faster kernel
*/
void GPUHistogram(data_size_t leaf_num_data, bool use_all_features);
/*!
* \brief Wait for GPU kernel execution and read histogram
* \param histograms Destination of histogram results from GPU.
* \param is_feature_used A predicate vector for enabling each feature
*/
template <typename HistType>
void WaitAndGetHistograms(HistogramBinEntry* histograms, const std::vector<int8_t>& is_feature_used);
/*!
* \brief Construct GPU histogram asynchronously.
* Interface is similar to Dataset::ConstructHistograms().
* \param is_feature_used A predicate vector for enabling each feature
* \param data_indices Array of data example IDs to be included in histogram, will be copied to GPU.
* Set to nullptr to skip copy to GPU.
* \param num_data Number of data examples to be included in histogram
* \param gradients Array of gradients for all examples.
* \param hessians Array of hessians for all examples.
* \param ordered_gradients Ordered gradients will be generated and copied to GPU when gradients is not nullptr,
* Set gradients to nullptr to skip copy to GPU.
* \param ordered_hessians Ordered hessians will be generated and copied to GPU when hessians is not nullptr,
* Set hessians to nullptr to skip copy to GPU.
* \return true if GPU kernel is launched, false if GPU is not used
*/
bool ConstructGPUHistogramsAsync(
const std::vector<int8_t>& is_feature_used,
const data_size_t* data_indices, data_size_t num_data,
const score_t* gradients, const score_t* hessians,
score_t* ordered_gradients, score_t* ordered_hessians);
/*! \brief Log2 of max number of workgroups per feature */
const int kMaxLogWorkgroupsPerFeature = 10; // 2^10
/*! \brief Max total number of workgroups with preallocated workspace.
* If we use more than this number of workgroups, we have to reallocate subhistograms */
int preallocd_max_num_wg_ = 1024;
/*! \brief True if bagging is used */
bool use_bagging_;
/*! \brief GPU device object */
boost::compute::device dev_;
/*! \brief GPU context object */
boost::compute::context ctx_;
/*! \brief GPU command queue object */
boost::compute::command_queue queue_;
/*! \brief GPU kernel for 256 bins
* (OpenCL source embedded as a string literal via #include) */
const char *kernel256_src_ =
#include "ocl/histogram256.cl"
;
/*! \brief GPU kernel for 64 bins */
const char *kernel64_src_ =
#include "ocl/histogram64.cl"
;
/*! \brief GPU kernel for 16 bins */
const char *kernel16_src_ =
#include "ocl/histogram16.cl"
;
/*! \brief Currently used kernel source */
std::string kernel_source_;
/*! \brief Currently used kernel name */
std::string kernel_name_;
/*! \brief an array of histogram kernels with different number
of workgroups per feature */
std::vector<boost::compute::kernel> histogram_kernels_;
/*! \brief an array of histogram kernels with different number
of workgroups per feature, with all features enabled to avoid branches */
std::vector<boost::compute::kernel> histogram_allfeats_kernels_;
/*! \brief an array of histogram kernels with different number
of workgroups per feature, and processing the whole dataset */
std::vector<boost::compute::kernel> histogram_fulldata_kernels_;
/*! \brief total number of feature-groups */
int num_feature_groups_;
/*! \brief total number of dense feature-groups, which will be processed on GPU */
int num_dense_feature_groups_;
/*! \brief On GPU we read one DWORD (4-byte) of features of one example once.
* With bin size > 16, there are 4 features per DWORD.
* With bin size <=16, there are 8 features per DWORD.
* */
int dword_features_;
/*! \brief total number of dense feature-group tuples on GPU.
* Each feature tuple is 4-byte (4 features if each feature takes a byte) */
int num_dense_feature4_;
/*! \brief Max number of bins of training data, used to determine
* which GPU kernel to use */
int max_num_bin_;
/*! \brief Used GPU kernel bin size (64, 256) */
int device_bin_size_;
/*! \brief Size of histogram bin entry, depending if single or double precision is used */
size_t hist_bin_entry_sz_;
/*! \brief Indices of all dense feature-groups */
std::vector<int> dense_feature_group_map_;
/*! \brief Indices of all sparse feature-groups */
std::vector<int> sparse_feature_group_map_;
/*! \brief Multipliers of all dense feature-groups, used for redistributing bins */
std::vector<int> device_bin_mults_;
/*! \brief GPU memory object holding the training data */
std::unique_ptr<boost::compute::vector<Feature4>> device_features_;
/*! \brief GPU memory object holding the ordered gradient */
boost::compute::buffer device_gradients_;
/*! \brief Pinned memory object for ordered gradient */
boost::compute::buffer pinned_gradients_;
/*! \brief Pointer to pinned memory of ordered gradient */
void * ptr_pinned_gradients_ = nullptr;
/*! \brief GPU memory object holding the ordered hessian */
boost::compute::buffer device_hessians_;
/*! \brief Pinned memory object for ordered hessian */
boost::compute::buffer pinned_hessians_;
/*! \brief Pointer to pinned memory of ordered hessian */
void * ptr_pinned_hessians_ = nullptr;
/*! \brief A vector of feature mask. 1 = feature used, 0 = feature not used.
* 4KB-aligned so the buffer can be pinned for fast host-to-device transfer. */
std::vector<char, boost::alignment::aligned_allocator<char, 4096>> feature_masks_;
/*! \brief GPU memory object holding the feature masks */
boost::compute::buffer device_feature_masks_;
/*! \brief Pinned memory object for feature masks */
boost::compute::buffer pinned_feature_masks_;
/*! \brief Pointer to pinned memory of feature masks */
void * ptr_pinned_feature_masks_ = nullptr;
/*! \brief GPU memory object holding indices of the leaf being processed */
std::unique_ptr<boost::compute::vector<data_size_t>> device_data_indices_;
/*! \brief GPU memory object holding counters for workgroup coordination */
std::unique_ptr<boost::compute::vector<int>> sync_counters_;
/*! \brief GPU memory object holding temporary sub-histograms per workgroup */
std::unique_ptr<boost::compute::vector<char>> device_subhistograms_;
/*! \brief Host memory object for histogram output (GPU will write to Host memory directly) */
boost::compute::buffer device_histogram_outputs_;
/*! \brief Host memory pointer for histogram outputs */
void * host_histogram_outputs_;
/*! \brief OpenCL waitlist object for waiting for data transfer before kernel execution */
boost::compute::wait_list kernel_wait_obj_;
/*! \brief OpenCL waitlist object for reading output histograms after kernel execution */
boost::compute::wait_list histograms_wait_obj_;
/*! \brief Asynchronous waiting object for copying indices */
boost::compute::future<void> indices_future_;
/*! \brief Asynchronous waiting object for copying gradients */
boost::compute::event gradients_future_;
/*! \brief Asynchronous waiting object for copying hessians */
boost::compute::event hessians_future_;
};
} // namespace LightGBM
#else
// When GPU support is not compiled in, quit with an error message
namespace LightGBM {
/*!
* \brief Stub used when the library is built without GPU support.
* Constructing it aborts immediately with an instructive error, so selecting
* device=gpu on a CPU-only build fails fast instead of silently falling back.
*/
class GPUTreeLearner: public SerialTreeLearner {
public:
explicit GPUTreeLearner(const TreeConfig* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("GPU Tree Learner was not enabled in this build. Recompile with CMake option -DUSE_GPU=1");
}
};
}
#endif // USE_GPU
#endif // LightGBM_TREELEARNER_GPU_TREE_LEARNER_H_
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <LightGBM/network.h> #include <LightGBM/network.h>
#include "serial_tree_learner.h" #include "serial_tree_learner.h"
#include "gpu_tree_learner.h"
#include <cstring> #include <cstring>
...@@ -18,11 +19,12 @@ namespace LightGBM { ...@@ -18,11 +19,12 @@ namespace LightGBM {
* Different machine will find best split on different features, then sync global best split * Different machine will find best split on different features, then sync global best split
* It is recommonded used when #data is small or #feature is large * It is recommonded used when #data is small or #feature is large
*/ */
class FeatureParallelTreeLearner: public SerialTreeLearner { template <typename TREELEARNER_T>
class FeatureParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config); explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner(); ~FeatureParallelTreeLearner();
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
...@@ -43,11 +45,12 @@ private: ...@@ -43,11 +45,12 @@ private:
* Workers use local data to construct histograms locally, then sync up global histograms. * Workers use local data to construct histograms locally, then sync up global histograms.
* It is recommonded used when #data is large or #feature is small * It is recommonded used when #data is large or #feature is small
*/ */
class DataParallelTreeLearner: public SerialTreeLearner { template <typename TREELEARNER_T>
class DataParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit DataParallelTreeLearner(const TreeConfig* tree_config); explicit DataParallelTreeLearner(const TreeConfig* tree_config);
~DataParallelTreeLearner(); ~DataParallelTreeLearner();
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const TreeConfig* tree_config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
...@@ -95,11 +98,12 @@ private: ...@@ -95,11 +98,12 @@ private:
* Here using voting to reduce features, and only aggregate histograms for selected features. * Here using voting to reduce features, and only aggregate histograms for selected features.
* When #data is large and #feature is large, you can use this to have better speed-up * When #data is large and #feature is large, you can use this to have better speed-up
*/ */
class VotingParallelTreeLearner: public SerialTreeLearner { template <typename TREELEARNER_T>
class VotingParallelTreeLearner: public TREELEARNER_T {
public: public:
explicit VotingParallelTreeLearner(const TreeConfig* tree_config); explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
~VotingParallelTreeLearner() { } ~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override; void ResetConfig(const TreeConfig* tree_config) override;
protected: protected:
void BeforeTrain() override; void BeforeTrain() override;
......
...@@ -37,10 +37,11 @@ SerialTreeLearner::~SerialTreeLearner() { ...@@ -37,10 +37,11 @@ SerialTreeLearner::~SerialTreeLearner() {
#endif #endif
} }
void SerialTreeLearner::Init(const Dataset* train_data) { void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
train_data_ = train_data; train_data_ = train_data;
num_data_ = train_data_->num_data(); num_data_ = train_data_->num_data();
num_features_ = train_data_->num_features(); num_features_ = train_data_->num_features();
is_constant_hessian_ = is_constant_hessian;
int max_cache_size = 0; int max_cache_size = 0;
// Get the max size of pool // Get the max size of pool
if (tree_config_->histogram_pool_size <= 0) { if (tree_config_->histogram_pool_size <= 0) {
......
...@@ -18,6 +18,11 @@ ...@@ -18,6 +18,11 @@
#include <random> #include <random>
#include <cmath> #include <cmath>
#include <memory> #include <memory>
#ifdef USE_GPU
// Use 4KBytes aligned allocator for ordered gradients and ordered hessians when GPU is enabled.
// This is necessary to pin the two arrays in memory and make transferring faster.
#include <boost/align/aligned_allocator.hpp>
#endif
namespace LightGBM { namespace LightGBM {
...@@ -30,7 +35,7 @@ public: ...@@ -30,7 +35,7 @@ public:
~SerialTreeLearner(); ~SerialTreeLearner();
void Init(const Dataset* train_data) override; void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override; void ResetTrainingData(const Dataset* train_data) override;
...@@ -69,7 +74,7 @@ protected: ...@@ -69,7 +74,7 @@ protected:
*/ */
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf); virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract); virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
/*! /*!
* \brief Find best thresholds for all features, using multi-threading. * \brief Find best thresholds for all features, using multi-threading.
...@@ -130,10 +135,17 @@ protected: ...@@ -130,10 +135,17 @@ protected:
/*! \brief stores best thresholds for all feature for larger leaf */ /*! \brief stores best thresholds for all feature for larger leaf */
std::unique_ptr<LeafSplits> larger_leaf_splits_; std::unique_ptr<LeafSplits> larger_leaf_splits_;
#ifdef USE_GPU
/*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_hessians_;
#else
/*! \brief gradients of current iteration, ordered for cache optimized */ /*! \brief gradients of current iteration, ordered for cache optimized */
std::vector<score_t> ordered_gradients_; std::vector<score_t> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized */ /*! \brief hessians of current iteration, ordered for cache optimized */
std::vector<score_t> ordered_hessians_; std::vector<score_t> ordered_hessians_;
#endif
/*! \brief Store ordered bin */ /*! \brief Store ordered bin */
std::vector<std::unique_ptr<OrderedBin>> ordered_bins_; std::vector<std::unique_ptr<OrderedBin>> ordered_bins_;
......
#include <LightGBM/tree_learner.h> #include <LightGBM/tree_learner.h>
#include "serial_tree_learner.h" #include "serial_tree_learner.h"
#include "gpu_tree_learner.h"
#include "parallel_tree_learner.h" #include "parallel_tree_learner.h"
namespace LightGBM { namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& type, const TreeConfig* tree_config) { TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const TreeConfig* tree_config) {
if (type == std::string("serial")) { if (device_type == std::string("cpu")) {
return new SerialTreeLearner(tree_config); if (learner_type == std::string("serial")) {
} else if (type == std::string("feature")) { return new SerialTreeLearner(tree_config);
return new FeatureParallelTreeLearner(tree_config); } else if (learner_type == std::string("feature")) {
} else if (type == std::string("data")) { return new FeatureParallelTreeLearner<SerialTreeLearner>(tree_config);
return new DataParallelTreeLearner(tree_config); } else if (learner_type == std::string("data")) {
} else if (type == std::string("voting")) { return new DataParallelTreeLearner<SerialTreeLearner>(tree_config);
return new VotingParallelTreeLearner(tree_config); } else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<SerialTreeLearner>(tree_config);
}
}
else if (device_type == std::string("gpu")) {
if (learner_type == std::string("serial")) {
return new GPUTreeLearner(tree_config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<GPUTreeLearner>(tree_config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<GPUTreeLearner>(tree_config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<GPUTreeLearner>(tree_config);
}
} }
return nullptr; return nullptr;
} }
......
...@@ -220,10 +220,10 @@ class TestEngine(unittest.TestCase): ...@@ -220,10 +220,10 @@ class TestEngine(unittest.TestCase):
gbm3.save_model('categorical.model') gbm3.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model') gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test)) pred4 = list(gbm4.predict(X_test))
self.assertListEqual(pred0, pred1) np.testing.assert_almost_equal(pred0, pred1)
self.assertListEqual(pred0, pred2) np.testing.assert_almost_equal(pred0, pred2)
self.assertListEqual(pred0, pred3) np.testing.assert_almost_equal(pred0, pred3)
self.assertListEqual(pred0, pred4) np.testing.assert_almost_equal(pred0, pred4)
print("----------------------------------------------------------------------") print("----------------------------------------------------------------------")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment