"include/vscode:/vscode.git/clone" did not exist on "8a19834a6644ca2598704fd1208a4b50cbedd02d"
Unverified Commit f7ad9457 authored by Chip Kerchner's avatar Chip Kerchner Committed by GitHub
Browse files

[GPU] Add support for CUDA-based GPU build (#3160)



* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* redirect log to python console (#3090)

* redir log to python console

* fix pylint

* Apply suggestions from code review

* Update basic.py

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update c_api.h

* Apply suggestions from code review

* Apply suggestions from code review

* super-minor: better wording
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarStrikerRUS <nekit94-12@hotmail.com>

* re-order includes (fixes #3132) (#3133)

* Revert "re-order includes (fixes #3132) (#3133)" (#3153)

This reverts commit 656d2676

.

* Missing change from previous rebase

* Minor cleanup and removal of development scripts.

* Only set gpu_use_dp on by default for CUDA. Other minor change.

* Fix python lint indentation problem.

* More python lint issues.

* Big lint cleanup - more to come.

* Another large lint cleanup - more to come.

* Even more lint cleanup.

* Minor cleanup so less differences in code.

* Revert is_use_subset changes

* Another rebase from master to fix recent conflicts.

* More lint.

* Simple code cleanup - add & remove blank lines, revert unneccessary format changes, remove added dead code.

* Removed parameters added for CUDA and various bug fix.

* Yet more lint and unneccessary changes.

* Revert another change.

* Removal of unneccessary code.

* temporary appveyor.yml for building and testing

* Remove return value in ReSize

* Removal of unused variables.

* Code cleanup from reviewers suggestions.

* Removal of FIXME comments and unused defines.

* More reviewers comments cleanup.

* More reviewers comments cleanup.

* More reviewers comments cleanup.

* Fix config variables.

* Attempt to fix check-docs failure

* Update Paramster.rst for num_gpu

* Removing test appveyor.yml

* Add ƒCUDA_RESOLVE_DEVICE_SYMBOLS to libraries to fix linking issue.

* Fixed handling of data elements less than 2K.

* More reviewers comments cleanup.

* Removal of TODO and fix printing of int64_t

* Add cuda change for CI testing and remove cuda from device_type in python.

* Missed one change form previous check-in

* Removal AdditionConfig and fix settings.

* Limit number of GPUs to one for now in CUDA.

* Update Parameters.rst for previous check-in

* Whitespace removal.

* Cleanup unused code.

* Changed uint/ushort/ulong to unsigned int/short/long to help Windows based CUDA compiler work.

* Lint change from previous check-in.

* Changes based on reviewers comments.

* More reviewer comment changes.

* Adding warning for is_sparse. Revert tmp_subset code. Only return FeatureGroupData if not is_multi_val_

* Fix so that CUDA code will compile even if you enable the SCORE_T_USE_DOUBLE define.

* Reviewer comment cleanup.

* Replace warning with Log message. Removal of some of the USE_CUDA. Fix typo and removal of pragma once.

* Remove PRINT debug for CUDA code.

* Allow to use of multiple GPUs for CUDA.

* More multi-GPUs enablement for CUDA.

* More code cleanup based on reviews comments.

* Update docs with latest config changes.
Co-authored-by: default avatarGordon Fossum <fossum@us.ibm.com>
Co-authored-by: default avatarChipKerchner <ckerchne@linux.vnet.ibm.com>
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarStrikerRUS <nekit94-12@hotmail.com>
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent 1fddabb5
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
#define LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
#ifdef USE_CUDA
#include <chrono>
#include "kernels/histogram_16_64_256.hu" // kernel, acc_type, data_size_t, uchar, score_t
namespace LightGBM {
struct ThreadData {
// device id
int device_id;
// parameters for cuda_histogram
int histogram_size;
data_size_t leaf_num_data;
data_size_t num_data;
bool use_all_features;
bool is_constant_hessian;
int num_workgroups;
cudaStream_t stream;
uint8_t* device_features;
uint8_t* device_feature_masks;
data_size_t* device_data_indices;
score_t* device_gradients;
score_t* device_hessians;
score_t hessians_const;
char* device_subhistograms;
volatile int* sync_counters;
void* device_histogram_outputs;
size_t exp_workgroups_per_feature;
// cuda events
cudaEvent_t* kernel_start;
cudaEvent_t* kernel_wait_obj;
std::chrono::duration<double, std::milli>* kernel_input_wait_time;
// copy histogram
size_t output_size;
char* host_histogram_output;
cudaEvent_t* histograms_wait_obj;
};
void cuda_histogram(
int histogram_size,
data_size_t leaf_num_data,
data_size_t num_data,
bool use_all_features,
bool is_constant_hessian,
int num_workgroups,
cudaStream_t stream,
uint8_t* arg0,
uint8_t* arg1,
data_size_t arg2,
data_size_t* arg3,
data_size_t arg4,
score_t* arg5,
score_t* arg6,
score_t arg6_const,
char* arg7,
volatile int* arg8,
void* arg9,
size_t exp_workgroups_per_feature);
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
This diff is collapsed.
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/dataset.h>
#include <LightGBM/feature_group.h>
#include <LightGBM/tree.h>
#include <string>
#include <cmath>
#include <cstdio>
#include <memory>
#include <random>
#include <vector>
#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif
#include "feature_histogram.hpp"
#include "serial_tree_learner.h"
#include "data_partition.hpp"
#include "split_info.hpp"
#include "leaf_splits.hpp"
#ifdef USE_CUDA
#include <LightGBM/cuda/vector_cudahost.h>
#include "cuda_kernel_launcher.h"
using json11::Json;
namespace LightGBM {
/*!
* \brief CUDA-based parallel learning algorithm.
*/
class CUDATreeLearner: public SerialTreeLearner {
public:
explicit CUDATreeLearner(const Config* tree_config);
~CUDATreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override;
Tree* Train(const score_t* gradients, const score_t *hessians);
void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(subset, used_indices, num_data);
if (subset == nullptr && used_indices != nullptr) {
if (num_data != num_data_) {
use_bagging_ = true;
return;
}
}
use_bagging_ = false;
}
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestSplits(const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private:
typedef float gpu_hist_t;
/*!
* \brief Find the best number of workgroups processing one feature for maximizing efficiency
* \param leaf_num_data The number of data examples on the current leaf being processed
* \return Log2 of the best number for workgroups per feature, in range 0...kMaxLogWorkgroupsPerFeature
*/
int GetNumWorkgroupsPerFeature(data_size_t leaf_num_data);
/*!
* \brief Initialize GPU device
* \param num_gpu: number of maximum gpus
*/
void InitGPU(int num_gpu);
/*!
* \brief Allocate memory for GPU computation // alloc only
*/
void CountDenseFeatureGroups(); // compute num_dense_feature_group
void prevAllocateGPUMemory(); // compute CPU-side param calculation & Pin HostMemory
void AllocateGPUMemory();
/*!
* \ ResetGPUMemory
*/
void ResetGPUMemory();
/*!
* \ copy dense feature from CPU to GPU
*/
void copyDenseFeature();
/*!
* \brief Compute GPU feature histogram for the current leaf.
* Indices, gradients and hessians have been copied to the device.
* \param leaf_num_data Number of data on current leaf
* \param use_all_features Set to true to not use feature masks, with a faster kernel
*/
void GPUHistogram(data_size_t leaf_num_data, bool use_all_features);
void SetThreadData(ThreadData* thread_data, int device_id, int histogram_size,
int leaf_num_data, bool use_all_features,
int num_workgroups, int exp_workgroups_per_feature) {
ThreadData* td = &thread_data[device_id];
td->device_id = device_id;
td->histogram_size = histogram_size;
td->leaf_num_data = leaf_num_data;
td->num_data = num_data_;
td->use_all_features = use_all_features;
td->is_constant_hessian = share_state_->is_constant_hessian;
td->num_workgroups = num_workgroups;
td->stream = stream_[device_id];
td->device_features = device_features_[device_id];
td->device_feature_masks = reinterpret_cast<uint8_t *>(device_feature_masks_[device_id]);
td->device_data_indices = device_data_indices_[device_id];
td->device_gradients = device_gradients_[device_id];
td->device_hessians = device_hessians_[device_id];
td->hessians_const = hessians_[0];
td->device_subhistograms = device_subhistograms_[device_id];
td->sync_counters = sync_counters_[device_id];
td->device_histogram_outputs = device_histogram_outputs_[device_id];
td->exp_workgroups_per_feature = exp_workgroups_per_feature;
td->kernel_start = &(kernel_start_[device_id]);
td->kernel_wait_obj = &(kernel_wait_obj_[device_id]);
td->kernel_input_wait_time = &(kernel_input_wait_time_[device_id]);
size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
td->output_size = output_size;
td->host_histogram_output = reinterpret_cast<char*>(host_histogram_outputs_) + host_output_offset;
td->histograms_wait_obj = &(histograms_wait_obj_[device_id]);
}
/*!
* \brief Wait for GPU kernel execution and read histogram
* \param histograms Destination of histogram results from GPU.
*/
template <typename HistType>
void WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array);
/*!
* \brief Construct GPU histogram asynchronously.
* Interface is similar to Dataset::ConstructHistograms().
* \param is_feature_used A predicate vector for enabling each feature
* \param data_indices Array of data example IDs to be included in histogram, will be copied to GPU.
* Set to nullptr to skip copy to GPU.
* \param num_data Number of data examples to be included in histogram
* \return true if GPU kernel is launched, false if GPU is not used
*/
bool ConstructGPUHistogramsAsync(
const std::vector<int8_t>& is_feature_used,
const data_size_t* data_indices, data_size_t num_data);
/*! brief Log2 of max number of workgroups per feature*/
const int kMaxLogWorkgroupsPerFeature = 10; // 2^10
/*! brief Max total number of workgroups with preallocated workspace.
* If we use more than this number of workgroups, we have to reallocate subhistograms */
std::vector<int> preallocd_max_num_wg_;
/*! \brief True if bagging is used */
bool use_bagging_;
/*! \brief GPU command queue object */
std::vector<cudaStream_t> stream_;
/*! \brief total number of feature-groups */
int num_feature_groups_;
/*! \brief total number of dense feature-groups, which will be processed on GPU */
int num_dense_feature_groups_;
std::vector<int> num_gpu_feature_groups_;
std::vector<int> offset_gpu_feature_groups_;
/*! \brief On GPU we read one DWORD (4-byte) of features of one example once.
* With bin size > 16, there are 4 features per DWORD.
* With bin size <=16, there are 8 features per DWORD.
*/
int dword_features_;
/*! \brief Max number of bins of training data, used to determine
* which GPU kernel to use */
int max_num_bin_;
/*! \brief Used GPU kernel bin size (64, 256) */
int histogram_size_;
int device_bin_size_;
/*! \brief Size of histogram bin entry, depending if single or double precision is used */
size_t hist_bin_entry_sz_;
/*! \brief Indices of all dense feature-groups */
std::vector<int> dense_feature_group_map_;
/*! \brief Indices of all sparse feature-groups */
std::vector<int> sparse_feature_group_map_;
/*! \brief GPU memory object holding the training data */
std::vector<uint8_t*> device_features_;
/*! \brief GPU memory object holding the ordered gradient */
std::vector<score_t*> device_gradients_;
/*! \brief Pointer to pinned memory of ordered gradient */
void * ptr_pinned_gradients_ = nullptr;
/*! \brief GPU memory object holding the ordered hessian */
std::vector<score_t*> device_hessians_;
/*! \brief Pointer to pinned memory of ordered hessian */
void * ptr_pinned_hessians_ = nullptr;
/*! \brief A vector of feature mask. 1 = feature used, 0 = feature not used */
std::vector<char> feature_masks_;
/*! \brief GPU memory object holding the feature masks */
std::vector<char*> device_feature_masks_;
/*! \brief Pointer to pinned memory of feature masks */
char* ptr_pinned_feature_masks_ = nullptr;
/*! \brief GPU memory object holding indices of the leaf being processed */
std::vector<data_size_t*> device_data_indices_;
/*! \brief GPU memory object holding counters for workgroup coordination */
std::vector<int*> sync_counters_;
/*! \brief GPU memory object holding temporary sub-histograms per workgroup */
std::vector<char*> device_subhistograms_;
/*! \brief Host memory object for histogram output (GPU will write to Host memory directly) */
std::vector<void*> device_histogram_outputs_;
/*! \brief Host memory pointer for histogram outputs */
void *host_histogram_outputs_;
/*! CUDA waitlist object for waiting for data transfer before kernel execution */
std::vector<cudaEvent_t> kernel_wait_obj_;
/*! CUDA waitlist object for reading output histograms after kernel execution */
std::vector<cudaEvent_t> histograms_wait_obj_;
/*! CUDA Asynchronous waiting object for copying indices */
std::vector<cudaEvent_t> indices_future_;
/*! Asynchronous waiting object for copying gradients */
std::vector<cudaEvent_t> gradients_future_;
/*! Asynchronous waiting object for copying hessians */
std::vector<cudaEvent_t> hessians_future_;
/*! Asynchronous waiting object for copying dense features */
std::vector<cudaEvent_t> features_future_;
// host-side buffer for converting feature data into featre4 data
int nthreads_; // number of Feature4* vector on host4_vecs_
std::vector<cudaEvent_t> kernel_start_;
std::vector<float> kernel_time_; // measure histogram kernel time
std::vector<std::chrono::duration<double, std::milli>> kernel_input_wait_time_;
int num_gpu_;
int allocated_num_data_; // allocated data instances
pthread_t **cpu_threads_; // pthread, 1 cpu thread / gpu
};
} // namespace LightGBM
#else // USE_CUDA
// When GPU support is not compiled in, quit with an error message
namespace LightGBM {
class CUDATreeLearner: public SerialTreeLearner {
public:
#pragma warning(disable : 4702)
explicit CUDATreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("CUDA Tree Learner was not enabled in this build.\n"
"Please recompile with CMake option -DUSE_CUDA=1");
}
};
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
......@@ -256,6 +256,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, in
}
// instantiate template classes, otherwise linker cannot find the code
template class DataParallelTreeLearner<CUDATreeLearner>;
template class DataParallelTreeLearner<GPUTreeLearner>;
template class DataParallelTreeLearner<SerialTreeLearner>;
......
......@@ -77,6 +77,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
}
// instantiate template classes, otherwise linker cannot find the code
template class FeatureParallelTreeLearner<CUDATreeLearner>;
template class FeatureParallelTreeLearner<GPUTreeLearner>;
template class FeatureParallelTreeLearner<SerialTreeLearner>;
} // namespace LightGBM
......@@ -52,7 +52,7 @@ void PrintHistograms(hist_t* h, size_t size) {
double total_hess = 0;
for (size_t i = 0; i < size; ++i) {
printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
if ((i & 2) == 2)
if ((i & 3) == 3)
printf("\n");
total_hess += GET_HESS(h, i);
}
......@@ -1068,10 +1068,10 @@ void GPUTreeLearner::FindBestSplits(const Tree* tree) {
}
size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
printf("Feature %d smaller leaf:\n", feature_index);
PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
printf("Feature %d larger leaf:\n", feature_index);
PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
}
#endif
}
......
This diff is collapsed.
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
#define LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
#include "LightGBM/meta.h"
namespace LightGBM {
// use double precision or not
#ifndef USE_DP_FLOAT
#define USE_DP_FLOAT 1
#endif
// ignore hessian, and use the local memory for hessian as an additional bank for gradient
#ifndef CONST_HESSIAN
#define CONST_HESSIAN 0
#endif
typedef unsigned char uchar;
template<typename T>
__device__ double as_double(const T t) {
static_assert(sizeof(T) == sizeof(double), "size mismatch");
double d;
memcpy(&d, &t, sizeof(T));
return d;
}
template<typename T>
__device__ unsigned long long as_ulong_ulong(const T t) {
static_assert(sizeof(T) == sizeof(unsigned long long), "size mismatch");
unsigned long long u;
memcpy(&u, &t, sizeof(T));
return u;
}
template<typename T>
__device__ float as_float(const T t) {
static_assert(sizeof(T) == sizeof(float), "size mismatch");
float f;
memcpy(&f, &t, sizeof(T));
return f;
}
template<typename T>
__device__ unsigned int as_uint(const T t) {
static_assert(sizeof(T) == sizeof(unsigned int), "size_mismatch");
unsigned int u;
memcpy(&u, &t, sizeof(T));
return u;
}
template<typename T>
__device__ uchar4 as_uchar4(const T t) {
static_assert(sizeof(T) == sizeof(uchar4), "size mismatch");
uchar4 u;
memcpy(&u, &t, sizeof(T));
return u;
}
#if USE_DP_FLOAT == 1
typedef double acc_type;
typedef unsigned long long acc_int_type;
#define as_acc_type as_double
#define as_acc_int_type as_ulong_ulong
#else
typedef float acc_type;
typedef unsigned int acc_int_type;
#define as_acc_type as_float
#define as_acc_int_type as_uint
#endif
// use all features and do not use feature mask
#ifndef ENABLE_ALL_FEATURES
#define ENABLE_ALL_FEATURES 1
#endif
// define all of the different kernels
#define DECLARE_CONST_BUF(name) \
__global__ void name(__global const uchar* restrict feature_data_base, \
const uchar* restrict feature_masks,\
const data_size_t feature_size,\
const data_size_t* restrict data_indices, \
const data_size_t num_data, \
const score_t* restrict ordered_gradients, \
const score_t* restrict ordered_hessians,\
char* __restrict__ output_buf,\
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
#define DECLARE_CONST_HES_CONST_BUF(name) \
__global__ void name(const uchar* __restrict__ feature_data_base, \
const uchar* __restrict__ feature_masks,\
const data_size_t feature_size,\
const data_size_t* __restrict__ data_indices, \
const data_size_t num_data, \
const score_t* __restrict__ ordered_gradients, \
const score_t const_hessian,\
char* __restrict__ output_buf,\
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
#define DECLARE_CONST_HES(name) \
__global__ void name(const uchar* feature_data_base, \
const uchar* __restrict__ feature_masks,\
const data_size_t feature_size,\
const data_size_t* data_indices, \
const data_size_t num_data, \
const score_t* ordered_gradients, \
const score_t const_hessian,\
char* __restrict__ output_buf, \
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
#define DECLARE(name) \
__global__ void name(const uchar* feature_data_base, \
const uchar* __restrict__ feature_masks,\
const data_size_t feature_size,\
const data_size_t* data_indices, \
const data_size_t num_data, \
const score_t* ordered_gradients, \
const score_t* ordered_hessians,\
char* __restrict__ output_buf, \
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
DECLARE_CONST_HES(histogram16_allfeats);
DECLARE_CONST_HES(histogram16_fulldata);
DECLARE_CONST_HES(histogram16);
DECLARE(histogram16_allfeats);
DECLARE(histogram16_fulldata);
DECLARE(histogram16);
DECLARE_CONST_HES(histogram64_allfeats);
DECLARE_CONST_HES(histogram64_fulldata);
DECLARE_CONST_HES(histogram64);
DECLARE(histogram64_allfeats);
DECLARE(histogram64_fulldata);
DECLARE(histogram64);
DECLARE_CONST_HES(histogram256_allfeats);
DECLARE_CONST_HES(histogram256_fulldata);
DECLARE_CONST_HES(histogram256);
DECLARE(histogram256_allfeats);
DECLARE(histogram256_fulldata);
DECLARE(histogram256);
} // namespace LightGBM
#endif // LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
......@@ -12,6 +12,7 @@
#include <memory>
#include <vector>
#include "cuda_tree_learner.h"
#include "gpu_tree_learner.h"
#include "serial_tree_learner.h"
......
......@@ -326,7 +326,16 @@ void SerialTreeLearner::FindBestSplits(const Tree* tree) {
is_feature_used[feature_index] = 1;
}
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
#ifdef USE_CUDA
if (LGBM_config_::current_learner == use_cpu_learner) {
SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
} else {
ConstructHistograms(is_feature_used, use_subtract);
}
#else
ConstructHistograms(is_feature_used, use_subtract);
#endif
FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}
......
......@@ -8,6 +8,7 @@
#include <LightGBM/dataset.h>
#include <LightGBM/tree.h>
#include <LightGBM/tree_learner.h>
#include <LightGBM/cuda/vector_cudahost.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/json11.h>
#include <LightGBM/utils/random.h>
......@@ -201,6 +202,11 @@ class SerialTreeLearner: public TreeLearner {
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_hessians_;
#elif USE_CUDA
/*! \brief gradients of current iteration, ordered for cache optimized */
std::vector<score_t, CHAllocator<score_t>> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized */
std::vector<score_t, CHAllocator<score_t>> ordered_hessians_;
#else
/*! \brief gradients of current iteration, ordered for cache optimized */
std::vector<score_t, Common::AlignmentAllocator<score_t, kAlignedSize>> ordered_gradients_;
......
......@@ -4,6 +4,7 @@
*/
#include <LightGBM/tree_learner.h>
#include "cuda_tree_learner.h"
#include "gpu_tree_learner.h"
#include "parallel_tree_learner.h"
#include "serial_tree_learner.h"
......@@ -31,6 +32,16 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<GPUTreeLearner>(config);
}
} else if (device_type == std::string("cuda")) {
if (learner_type == std::string("serial")) {
return new CUDATreeLearner(config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<CUDATreeLearner>(config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<CUDATreeLearner>(config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<CUDATreeLearner>(config);
}
}
return nullptr;
}
......
......@@ -454,6 +454,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
}
// instantiate template classes, otherwise linker cannot find the code
template class VotingParallelTreeLearner<CUDATreeLearner>;
template class VotingParallelTreeLearner<GPUTreeLearner>;
template class VotingParallelTreeLearner<SerialTreeLearner>;
} // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment