Unverified Commit f7ad9457 authored by Chip Kerchner's avatar Chip Kerchner Committed by GitHub
Browse files

[GPU] Add support for CUDA-based GPU build (#3160)



* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* redirect log to python console (#3090)

* redir log to python console

* fix pylint

* Apply suggestions from code review

* Update basic.py

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update c_api.h

* Apply suggestions from code review

* Apply suggestions from code review

* super-minor: better wording
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarStrikerRUS <nekit94-12@hotmail.com>

* re-order includes (fixes #3132) (#3133)

* Revert "re-order includes (fixes #3132) (#3133)" (#3153)

This reverts commit 656d2676

.

* Missing change from previous rebase

* Minor cleanup and removal of development scripts.

* Only set gpu_use_dp on by default for CUDA. Other minor change.

* Fix python lint indentation problem.

* More python lint issues.

* Big lint cleanup - more to come.

* Another large lint cleanup - more to come.

* Even more lint cleanup.

* Minor cleanup so less differences in code.

* Revert is_use_subset changes

* Another rebase from master to fix recent conflicts.

* More lint.

* Simple code cleanup - add & remove blank lines, revert unneccessary format changes, remove added dead code.

* Removed parameters added for CUDA and various bug fix.

* Yet more lint and unneccessary changes.

* Revert another change.

* Removal of unneccessary code.

* temporary appveyor.yml for building and testing

* Remove return value in ReSize

* Removal of unused variables.

* Code cleanup from reviewers suggestions.

* Removal of FIXME comments and unused defines.

* More reviewers comments cleanup.

* More reviewers comments cleanup.

* More reviewers comments cleanup.

* Fix config variables.

* Attempt to fix check-docs failure

* Update Paramster.rst for num_gpu

* Removing test appveyor.yml

* Add ƒCUDA_RESOLVE_DEVICE_SYMBOLS to libraries to fix linking issue.

* Fixed handling of data elements less than 2K.

* More reviewers comments cleanup.

* Removal of TODO and fix printing of int64_t

* Add cuda change for CI testing and remove cuda from device_type in python.

* Missed one change form previous check-in

* Removal AdditionConfig and fix settings.

* Limit number of GPUs to one for now in CUDA.

* Update Parameters.rst for previous check-in

* Whitespace removal.

* Cleanup unused code.

* Changed uint/ushort/ulong to unsigned int/short/long to help Windows based CUDA compiler work.

* Lint change from previous check-in.

* Changes based on reviewers comments.

* More reviewer comment changes.

* Adding warning for is_sparse. Revert tmp_subset code. Only return FeatureGroupData if not is_multi_val_

* Fix so that CUDA code will compile even if you enable the SCORE_T_USE_DOUBLE define.

* Reviewer comment cleanup.

* Replace warning with Log message. Removal of some of the USE_CUDA. Fix typo and removal of pragma once.

* Remove PRINT debug for CUDA code.

* Allow to use of multiple GPUs for CUDA.

* More multi-GPUs enablement for CUDA.

* More code cleanup based on reviews comments.

* Update docs with latest config changes.
Co-authored-by: default avatarGordon Fossum <fossum@us.ibm.com>
Co-authored-by: default avatarChipKerchner <ckerchne@linux.vnet.ibm.com>
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarStrikerRUS <nekit94-12@hotmail.com>
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent 1fddabb5
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
#define LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
#ifdef USE_CUDA
#include <chrono>
#include "kernels/histogram_16_64_256.hu" // kernel, acc_type, data_size_t, uchar, score_t
namespace LightGBM {
struct ThreadData {
// device id
int device_id;
// parameters for cuda_histogram
int histogram_size;
data_size_t leaf_num_data;
data_size_t num_data;
bool use_all_features;
bool is_constant_hessian;
int num_workgroups;
cudaStream_t stream;
uint8_t* device_features;
uint8_t* device_feature_masks;
data_size_t* device_data_indices;
score_t* device_gradients;
score_t* device_hessians;
score_t hessians_const;
char* device_subhistograms;
volatile int* sync_counters;
void* device_histogram_outputs;
size_t exp_workgroups_per_feature;
// cuda events
cudaEvent_t* kernel_start;
cudaEvent_t* kernel_wait_obj;
std::chrono::duration<double, std::milli>* kernel_input_wait_time;
// copy histogram
size_t output_size;
char* host_histogram_output;
cudaEvent_t* histograms_wait_obj;
};
void cuda_histogram(
int histogram_size,
data_size_t leaf_num_data,
data_size_t num_data,
bool use_all_features,
bool is_constant_hessian,
int num_workgroups,
cudaStream_t stream,
uint8_t* arg0,
uint8_t* arg1,
data_size_t arg2,
data_size_t* arg3,
data_size_t arg4,
score_t* arg5,
score_t* arg6,
score_t arg6_const,
char* arg7,
volatile int* arg8,
void* arg9,
size_t exp_workgroups_per_feature);
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifdef USE_CUDA
#include "cuda_tree_learner.h"
#include <LightGBM/bin.h>
#include <LightGBM/network.h>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/common.h>
#include <pthread.h>
#include <algorithm>
#include <cinttypes>
#include <vector>
#include "../io/dense_bin.hpp"
namespace LightGBM {
#define cudaMemcpy_DEBUG 0 // 1: DEBUG cudaMemcpy
#define ResetTrainingData_DEBUG 0 // 1: Debug ResetTrainingData
#define CUDA_DEBUG 0
static void *launch_cuda_histogram(void *thread_data) {
ThreadData td = *(reinterpret_cast<ThreadData*>(thread_data));
int device_id = td.device_id;
CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
// launch cuda kernel
cuda_histogram(td.histogram_size,
td.leaf_num_data, td.num_data, td.use_all_features,
td.is_constant_hessian, td.num_workgroups, td.stream,
td.device_features,
td.device_feature_masks,
td.num_data,
td.device_data_indices,
td.leaf_num_data,
td.device_gradients,
td.device_hessians, td.hessians_const,
td.device_subhistograms, td.sync_counters,
td.device_histogram_outputs,
td.exp_workgroups_per_feature);
CUDASUCCESS_OR_FATAL(cudaGetLastError());
return NULL;
}
CUDATreeLearner::CUDATreeLearner(const Config* config)
:SerialTreeLearner(config) {
use_bagging_ = false;
nthreads_ = 0;
if (config->gpu_use_dp && USE_DP_FLOAT) {
Log::Info("LightGBM using CUDA trainer with DP float!!");
} else {
Log::Info("LightGBM using CUDA trainer with SP float!!");
}
}
CUDATreeLearner::~CUDATreeLearner() {
}
void CUDATreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
// initialize SerialTreeLearner
SerialTreeLearner::Init(train_data, is_constant_hessian);
// some additional variables needed for GPU trainer
num_feature_groups_ = train_data_->num_feature_groups();
// Initialize GPU buffers and kernels: get device info
InitGPU(config_->num_gpu);
}
// some functions used for debugging the GPU histogram construction
#if CUDA_DEBUG > 0
void PrintHistograms(hist_t* h, size_t size) {
double total_hess = 0;
for (size_t i = 0; i < size; ++i) {
printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
if ((i & 3) == 3)
printf("\n");
total_hess += GET_HESS(h, i);
}
printf("\nSum hessians: %9.3g\n", total_hess);
}
union Float_t {
int64_t i;
double f;
static int64_t ulp_diff(Float_t a, Float_t b) {
return abs(a.i - b.i);
}
};
int CompareHistograms(hist_t* h1, hist_t* h2, size_t size, int feature_id, int dp_flag, int const_flag) {
int i;
int retval = 0;
printf("Comparing Histograms, feature_id = %d, size = %d\n", feature_id, static_cast<int>(size));
if (dp_flag) { // double precision
double af, bf;
int64_t ai, bi;
for (i = 0; i < static_cast<int>(size); ++i) {
af = GET_GRAD(h1, i);
bf = GET_GRAD(h2, i);
if ((((std::fabs(af - bf))/af) >= 1e-6) && ((std::fabs(af - bf)) >= 1e-6)) {
printf("i = %5d, h1.grad %13.6lf, h2.grad %13.6lf\n", i, af, bf);
++retval;
}
if (const_flag) {
ai = GET_HESS((reinterpret_cast<int64_t *>(h1)), i);
bi = GET_HESS((reinterpret_cast<int64_t *>(h2)), i);
if (ai != bi) {
printf("i = %5d, h1.hess %" PRId64 ", h2.hess %" PRId64 "\n", i, ai, bi);
++retval;
}
} else {
af = GET_HESS(h1, i);
bf = GET_HESS(h2, i);
if (((std::fabs(af - bf))/af) >= 1e-6) {
printf("i = %5d, h1.hess %13.6lf, h2.hess %13.6lf\n", i, af, bf);
++retval;
}
}
}
} else { // single precision
float af, bf;
int ai, bi;
for (i = 0; i < static_cast<int>(size); ++i) {
af = GET_GRAD(h1, i);
bf = GET_GRAD(h2, i);
if ((((std::fabs(af - bf))/af) >= 1e-6) && ((std::fabs(af - bf)) >= 1e-6)) {
printf("i = %5d, h1.grad %13.6f, h2.grad %13.6f\n", i, af, bf);
++retval;
}
if (const_flag) {
ai = GET_HESS(h1, i);
bi = GET_HESS(h2, i);
if (ai != bi) {
printf("i = %5d, h1.hess %d, h2.hess %d\n", i, ai, bi);
++retval;
}
} else {
af = GET_HESS(h1, i);
bf = GET_HESS(h2, i);
if (((std::fabs(af - bf))/af) >= 1e-5) {
printf("i = %5d, h1.hess %13.6f, h2.hess %13.6f\n", i, af, bf);
++retval;
}
}
}
}
printf("DONE Comparing Histograms...\n");
return retval;
}
#endif
int CUDATreeLearner::GetNumWorkgroupsPerFeature(data_size_t leaf_num_data) {
// we roughly want 256 workgroups per device, and we have num_dense_feature4_ feature tuples.
// also guarantee that there are at least 2K examples per workgroup
double x = 256.0 / num_dense_feature_groups_;
int exp_workgroups_per_feature = static_cast<int>(ceil(log2(x)));
double t = leaf_num_data / 1024.0;
Log::Debug("We can have at most %d workgroups per feature4 for efficiency reasons\n"
"Best workgroup size per feature for full utilization is %d\n", static_cast<int>(ceil(t)), (1 << exp_workgroups_per_feature));
exp_workgroups_per_feature = std::min(exp_workgroups_per_feature, static_cast<int>(ceil(log(static_cast<double>(t))/log(2.0))));
if (exp_workgroups_per_feature < 0)
exp_workgroups_per_feature = 0;
if (exp_workgroups_per_feature > kMaxLogWorkgroupsPerFeature)
exp_workgroups_per_feature = kMaxLogWorkgroupsPerFeature;
return exp_workgroups_per_feature;
}
void CUDATreeLearner::GPUHistogram(data_size_t leaf_num_data, bool use_all_features) {
// we have already copied ordered gradients, ordered hessians and indices to GPU
// decide the best number of workgroups working on one feature4 tuple
// set work group size based on feature size
// each 2^exp_workgroups_per_feature workgroups work on a feature4 tuple
int exp_workgroups_per_feature = GetNumWorkgroupsPerFeature(leaf_num_data);
std::vector<int> num_gpu_workgroups;
ThreadData *thread_data = reinterpret_cast<ThreadData*>(_mm_malloc(sizeof(ThreadData) * num_gpu_, 16));
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
int num_gpu_feature_groups = num_gpu_feature_groups_[device_id];
int num_workgroups = (1 << exp_workgroups_per_feature) * num_gpu_feature_groups;
num_gpu_workgroups.push_back(num_workgroups);
if (num_workgroups > preallocd_max_num_wg_[device_id]) {
preallocd_max_num_wg_.at(device_id) = num_workgroups;
CUDASUCCESS_OR_FATAL(cudaFree(device_subhistograms_[device_id]));
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_subhistograms_[device_id]), static_cast<size_t>(num_workgroups * dword_features_ * device_bin_size_ * (3 * hist_bin_entry_sz_ / 2))));
}
// set thread_data
SetThreadData(thread_data, device_id, histogram_size_, leaf_num_data, use_all_features,
num_workgroups, exp_workgroups_per_feature);
}
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
if (pthread_create(cpu_threads_[device_id], NULL, launch_cuda_histogram, reinterpret_cast<void *>(&thread_data[device_id]))) {
Log::Fatal("Error in creating threads.");
}
}
/* Wait for the threads to finish */
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
if (pthread_join(*(cpu_threads_[device_id]), NULL)) {
Log::Fatal("Error in joining threads.");
}
}
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
// copy the results asynchronously. Size depends on if double precision is used
size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(reinterpret_cast<char*>(host_histogram_outputs_) + host_output_offset, device_histogram_outputs_[device_id], output_size, cudaMemcpyDeviceToHost, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(histograms_wait_obj_[device_id], stream_[device_id]));
}
}
template <typename HistType>
void CUDATreeLearner::WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array) {
HistType* hist_outputs = reinterpret_cast<HistType*>(host_histogram_outputs_);
#pragma omp parallel for schedule(static, num_gpu_)
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
// when the output is ready, the computation is done
CUDASUCCESS_OR_FATAL(cudaEventSynchronize(histograms_wait_obj_[device_id]));
}
HistType* histograms = reinterpret_cast<HistType*>(leaf_histogram_array[0].RawData() - kHistOffset);
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_dense_feature_groups_; ++i) {
if (!feature_masks_[i]) {
continue;
}
int dense_group_index = dense_feature_group_map_[i];
auto old_histogram_array = histograms + train_data_->GroupBinBoundary(dense_group_index) * 2;
int bin_size = train_data_->FeatureGroupNumBin(dense_group_index);
for (int j = 0; j < bin_size; ++j) {
GET_GRAD(old_histogram_array, j) = GET_GRAD(hist_outputs, i * device_bin_size_+ j);
GET_HESS(old_histogram_array, j) = GET_HESS(hist_outputs, i * device_bin_size_+ j);
}
}
}
void CUDATreeLearner::CountDenseFeatureGroups() {
num_dense_feature_groups_ = 0;
for (int i = 0; i < num_feature_groups_; ++i) {
if (!train_data_->IsMultiGroup(i)) {
num_dense_feature_groups_++;
}
}
if (!num_dense_feature_groups_) {
Log::Warning("GPU acceleration is disabled because no non-trival dense features can be found");
}
}
void CUDATreeLearner::prevAllocateGPUMemory() {
// how many feature-group tuples we have
// leave some safe margin for prefetching
// 256 work-items per workgroup. Each work-item prefetches one tuple for that feature
allocated_num_data_ = std::max(num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature), allocated_num_data_);
// clear sparse/dense maps
dense_feature_group_map_.clear();
sparse_feature_group_map_.clear();
// do nothing it there is no dense feature
if (!num_dense_feature_groups_) {
return;
}
// calculate number of feature groups per gpu
num_gpu_feature_groups_.resize(num_gpu_);
offset_gpu_feature_groups_.resize(num_gpu_);
int num_features_per_gpu = num_dense_feature_groups_ / num_gpu_;
int remain_features = num_dense_feature_groups_ - num_features_per_gpu * num_gpu_;
int offset = 0;
for (int i = 0; i < num_gpu_; ++i) {
offset_gpu_feature_groups_.at(i) = offset;
num_gpu_feature_groups_.at(i) = (i < remain_features) ? num_features_per_gpu + 1 : num_features_per_gpu;
offset += num_gpu_feature_groups_.at(i);
}
feature_masks_.resize(num_dense_feature_groups_);
Log::Debug("Resized feature masks");
ptr_pinned_feature_masks_ = feature_masks_.data();
Log::Debug("Memset pinned_feature_masks_");
memset(ptr_pinned_feature_masks_, 0, num_dense_feature_groups_);
// histogram bin entry size depends on the precision (single/double)
hist_bin_entry_sz_ = 2 * (config_->gpu_use_dp ? sizeof(hist_t) : sizeof(gpu_hist_t)); // two elements in this "size"
CUDASUCCESS_OR_FATAL(cudaHostAlloc(reinterpret_cast<void **>(&host_histogram_outputs_), static_cast<size_t>(num_dense_feature_groups_ * device_bin_size_ * hist_bin_entry_sz_), cudaHostAllocPortable));
nthreads_ = std::min(omp_get_max_threads(), num_dense_feature_groups_ / dword_features_);
nthreads_ = std::max(nthreads_, 1);
}
// allocate GPU memory for each GPU
void CUDATreeLearner::AllocateGPUMemory() {
#pragma omp parallel for schedule(static, num_gpu_)
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
// do nothing it there is no gpu feature
int num_gpu_feature_groups = num_gpu_feature_groups_[device_id];
if (num_gpu_feature_groups) {
CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
// allocate memory for all features
if (device_features_[device_id] != NULL) {
CUDASUCCESS_OR_FATAL(cudaFree(device_features_[device_id]));
}
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_features_[device_id]), static_cast<size_t>(num_gpu_feature_groups * num_data_ * sizeof(uint8_t))));
Log::Debug("Allocated device_features_ addr=%p sz=%lu", device_features_[device_id], num_gpu_feature_groups * num_data_);
// allocate space for gradients and hessians on device
// we will copy gradients and hessians in after ordered_gradients_ and ordered_hessians_ are constructed
if (device_gradients_[device_id] != NULL) {
CUDASUCCESS_OR_FATAL(cudaFree(device_gradients_[device_id]));
}
if (device_hessians_[device_id] != NULL) {
CUDASUCCESS_OR_FATAL(cudaFree(device_hessians_[device_id]));
}
if (device_feature_masks_[device_id] != NULL) {
CUDASUCCESS_OR_FATAL(cudaFree(device_feature_masks_[device_id]));
}
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_gradients_[device_id]), static_cast<size_t>(allocated_num_data_ * sizeof(score_t))));
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_hessians_[device_id]), static_cast<size_t>(allocated_num_data_ * sizeof(score_t))));
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_feature_masks_[device_id]), static_cast<size_t>(num_gpu_feature_groups)));
// copy indices to the device
if (device_data_indices_[device_id] != NULL) {
CUDASUCCESS_OR_FATAL(cudaFree(device_data_indices_[device_id]));
}
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_data_indices_[device_id]), static_cast<size_t>(allocated_num_data_ * sizeof(data_size_t))));
CUDASUCCESS_OR_FATAL(cudaMemsetAsync(device_data_indices_[device_id], 0, allocated_num_data_ * sizeof(data_size_t), stream_[device_id]));
Log::Debug("Memset device_data_indices_");
// create output buffer, each feature has a histogram with device_bin_size_ bins,
// each work group generates a sub-histogram of dword_features_ features.
if (!device_subhistograms_[device_id]) {
// only initialize once here, as this will not need to change when ResetTrainingData() is called
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_subhistograms_[device_id]), static_cast<size_t>(preallocd_max_num_wg_[device_id] * dword_features_ * device_bin_size_ * (3 * hist_bin_entry_sz_ / 2))));
Log::Debug("created device_subhistograms_: %p", device_subhistograms_[device_id]);
}
// create atomic counters for inter-group coordination
CUDASUCCESS_OR_FATAL(cudaFree(sync_counters_[device_id]));
CUDASUCCESS_OR_FATAL(cudaMalloc(&(sync_counters_[device_id]), static_cast<size_t>(num_gpu_feature_groups * sizeof(int))));
CUDASUCCESS_OR_FATAL(cudaMemsetAsync(sync_counters_[device_id], 0, num_gpu_feature_groups * sizeof(int), stream_[device_id]));
// The output buffer is allocated to host directly, to overlap compute and data transfer
CUDASUCCESS_OR_FATAL(cudaFree(device_histogram_outputs_[device_id]));
CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_histogram_outputs_[device_id]), static_cast<size_t>(num_gpu_feature_groups * device_bin_size_ * hist_bin_entry_sz_)));
}
}
}
void CUDATreeLearner::ResetGPUMemory() {
// clear sparse/dense maps
dense_feature_group_map_.clear();
sparse_feature_group_map_.clear();
}
void CUDATreeLearner::copyDenseFeature() {
if (num_feature_groups_ == 0) {
LGBM_config_::current_learner = use_cpu_learner;
return;
}
Log::Debug("Started copying dense features from CPU to GPU");
// find the dense feature-groups and group then into Feature4 data structure (several feature-groups packed into 4 bytes)
size_t copied_feature = 0;
// set device info
int device_id = 0;
uint8_t* device_features = device_features_[device_id];
CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
Log::Debug("Started copying dense features from CPU to GPU - 1");
for (int i = 0; i < num_feature_groups_; ++i) {
// looking for dword_features_ non-sparse feature-groups
if (!train_data_->IsMultiGroup(i)) {
dense_feature_group_map_.push_back(i);
auto sizes_in_byte = train_data_->FeatureGroupSizesInByte(i);
void* tmp_data = train_data_->FeatureGroupData(i);
Log::Debug("Started copying dense features from CPU to GPU - 2");
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(&device_features[copied_feature * num_data_], tmp_data, sizes_in_byte, cudaMemcpyHostToDevice, stream_[device_id]));
Log::Debug("Started copying dense features from CPU to GPU - 3");
copied_feature++;
// reset device info
if (copied_feature == static_cast<size_t>(num_gpu_feature_groups_[device_id])) {
CUDASUCCESS_OR_FATAL(cudaEventRecord(features_future_[device_id], stream_[device_id]));
device_id += 1;
copied_feature = 0;
if (device_id < num_gpu_) {
device_features = device_features_[device_id];
CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id));
}
}
} else {
sparse_feature_group_map_.push_back(i);
}
}
}
// InitGPU w/ num_gpu
void CUDATreeLearner::InitGPU(int num_gpu) {
// Get the max bin size, used for selecting best GPU kernel
max_num_bin_ = 0;
#if CUDA_DEBUG >= 1
printf("bin_size: ");
#endif
for (int i = 0; i < num_feature_groups_; ++i) {
if (train_data_->IsMultiGroup(i)) {
continue;
}
#if CUDA_DEBUG >= 1
printf("%d, ", train_data_->FeatureGroupNumBin(i));
#endif
max_num_bin_ = std::max(max_num_bin_, train_data_->FeatureGroupNumBin(i));
}
#if CUDA_DEBUG >= 1
printf("\n");
#endif
if (max_num_bin_ <= 16) {
device_bin_size_ = 16;
histogram_size_ = 16;
dword_features_ = 1;
} else if (max_num_bin_ <= 64) {
device_bin_size_ = 64;
histogram_size_ = 64;
dword_features_ = 1;
} else if (max_num_bin_ <= 256) {
Log::Debug("device_bin_size_ = 256");
device_bin_size_ = 256;
histogram_size_ = 256;
dword_features_ = 1;
} else {
Log::Fatal("bin size %d cannot run on GPU", max_num_bin_);
}
if (max_num_bin_ == 65) {
Log::Warning("Setting max_bin to 63 is sugguested for best performance");
}
if (max_num_bin_ == 17) {
Log::Warning("Setting max_bin to 15 is sugguested for best performance");
}
// get num_dense_feature_groups_
CountDenseFeatureGroups();
if (num_gpu > num_dense_feature_groups_) num_gpu = num_dense_feature_groups_;
// initialize GPU
int gpu_count;
CUDASUCCESS_OR_FATAL(cudaGetDeviceCount(&gpu_count));
num_gpu_ = (gpu_count < num_gpu) ? gpu_count : num_gpu;
// set cpu threads
cpu_threads_ = reinterpret_cast<pthread_t **>(_mm_malloc(sizeof(pthread_t *)*num_gpu_, 16));
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
cpu_threads_[device_id] = reinterpret_cast<pthread_t *>(_mm_malloc(sizeof(pthread_t), 16));
}
// resize device memory pointers
device_features_.resize(num_gpu_);
device_gradients_.resize(num_gpu_);
device_hessians_.resize(num_gpu_);
device_feature_masks_.resize(num_gpu_);
device_data_indices_.resize(num_gpu_);
sync_counters_.resize(num_gpu_);
device_subhistograms_.resize(num_gpu_);
device_histogram_outputs_.resize(num_gpu_);
// create stream & events to handle multiple GPUs
preallocd_max_num_wg_.resize(num_gpu_, 1024);
stream_.resize(num_gpu_);
hessians_future_.resize(num_gpu_);
gradients_future_.resize(num_gpu_);
indices_future_.resize(num_gpu_);
features_future_.resize(num_gpu_);
kernel_start_.resize(num_gpu_);
kernel_wait_obj_.resize(num_gpu_);
histograms_wait_obj_.resize(num_gpu_);
for (int i = 0; i < num_gpu_; ++i) {
CUDASUCCESS_OR_FATAL(cudaSetDevice(i));
CUDASUCCESS_OR_FATAL(cudaStreamCreate(&(stream_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(hessians_future_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(gradients_future_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(indices_future_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(features_future_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(kernel_start_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(kernel_wait_obj_[i])));
CUDASUCCESS_OR_FATAL(cudaEventCreate(&(histograms_wait_obj_[i])));
}
allocated_num_data_ = 0;
prevAllocateGPUMemory();
AllocateGPUMemory();
copyDenseFeature();
}
Tree* CUDATreeLearner::Train(const score_t* gradients, const score_t *hessians) {
Tree *ret = SerialTreeLearner::Train(gradients, hessians);
return ret;
}
void CUDATreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) {
// check data size
data_size_t old_allocated_num_data = allocated_num_data_;
SerialTreeLearner::ResetTrainingDataInner(train_data, is_constant_hessian, reset_multi_val_bin);
#if ResetTrainingData_DEBUG == 1
serial_time = std::chrono::steady_clock::now() - start_serial_time;
#endif
num_feature_groups_ = train_data_->num_feature_groups();
// GPU memory has to been reallocated because data may have been changed
#if ResetTrainingData_DEBUG == 1
auto start_alloc_gpu_time = std::chrono::steady_clock::now();
#endif
// AllocateGPUMemory only when the number of data increased
int old_num_feature_groups = num_dense_feature_groups_;
CountDenseFeatureGroups();
if ((old_allocated_num_data < (num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature))) || (old_num_feature_groups < num_dense_feature_groups_)) {
prevAllocateGPUMemory();
AllocateGPUMemory();
} else {
ResetGPUMemory();
}
copyDenseFeature();
#if ResetTrainingData_DEBUG == 1
alloc_gpu_time = std::chrono::steady_clock::now() - start_alloc_gpu_time;
#endif
// setup GPU kernel arguments after we allocating all the buffers
#if ResetTrainingData_DEBUG == 1
auto start_set_arg_time = std::chrono::steady_clock::now();
#endif
#if ResetTrainingData_DEBUG == 1
set_arg_time = std::chrono::steady_clock::now() - start_set_arg_time;
reset_training_data_time = std::chrono::steady_clock::now() - start_reset_training_data_time;
Log::Info("reset_training_data_time: %f secs.", reset_training_data_time.count() * 1e-3);
Log::Info("serial_time: %f secs.", serial_time.count() * 1e-3);
Log::Info("alloc_gpu_time: %f secs.", alloc_gpu_time.count() * 1e-3);
Log::Info("set_arg_time: %f secs.", set_arg_time.count() * 1e-3);
#endif
}
void CUDATreeLearner::BeforeTrain() {
#if cudaMemcpy_DEBUG == 1
std::chrono::duration<double, std::milli> device_hessians_time = std::chrono::milliseconds(0);
std::chrono::duration<double, std::milli> device_gradients_time = std::chrono::milliseconds(0);
#endif
SerialTreeLearner::BeforeTrain();
#if CUDA_DEBUG >= 2
printf("CUDATreeLearner::BeforeTrain() Copying initial full gradients and hessians to device\n");
#endif
// Copy initial full hessians and gradients to GPU.
// We start copying as early as possible, instead of at ConstructHistogram().
if ((hessians_ != NULL) && (gradients_ != NULL)) {
if (!use_bagging_ && num_dense_feature_groups_) {
Log::Debug("CudaTreeLearner::BeforeTrain() No baggings, dense_feature_groups_=%d", num_dense_feature_groups_);
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
if (!(share_state_->is_constant_hessian)) {
Log::Debug("CUDATreeLearner::BeforeTrain(): Starting hessians_ -> device_hessians_");
#if cudaMemcpy_DEBUG == 1
auto start_device_hessians_time = std::chrono::steady_clock::now();
#endif
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_hessians_[device_id], hessians_, num_data_*sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(hessians_future_[device_id], stream_[device_id]));
#if cudaMemcpy_DEBUG == 1
device_hessians_time = std::chrono::steady_clock::now() - start_device_hessians_time;
#endif
Log::Debug("queued copy of device_hessians_");
}
#if cudaMemcpy_DEBUG == 1
auto start_device_gradients_time = std::chrono::steady_clock::now();
#endif
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_gradients_[device_id], gradients_, num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(gradients_future_[device_id], stream_[device_id]));
#if cudaMemcpy_DEBUG == 1
device_gradients_time = std::chrono::steady_clock::now() - start_device_gradients_time;
#endif
Log::Debug("CUDATreeLearner::BeforeTrain: issued gradients_ -> device_gradients_");
}
}
}
// use bagging
if ((hessians_ != NULL) && (gradients_ != NULL)) {
if (data_partition_->leaf_count(0) != num_data_ && num_dense_feature_groups_) {
// On GPU, we start copying indices, gradients and hessians now, instead at ConstructHistogram()
// copy used gradients and hessians to ordered buffer
const data_size_t* indices = data_partition_->indices();
data_size_t cnt = data_partition_->leaf_count(0);
// transfer the indices to GPU
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], indices, cnt * sizeof(*indices), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id]));
if (!(share_state_->is_constant_hessian)) {
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_hessians_[device_id], const_cast<void*>(reinterpret_cast<const void*>(&(hessians_[0]))), num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(hessians_future_[device_id], stream_[device_id]));
}
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_gradients_[device_id], const_cast<void*>(reinterpret_cast<const void*>(&(gradients_[0]))), num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(gradients_future_[device_id], stream_[device_id]));
}
}
}
}
bool CUDATreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
int smaller_leaf;
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// only have root
if (right_leaf < 0) {
smaller_leaf = -1;
} else if (num_data_in_left_child < num_data_in_right_child) {
smaller_leaf = left_leaf;
} else {
smaller_leaf = right_leaf;
}
// Copy indices, gradients and hessians as early as possible
if (smaller_leaf >= 0 && num_dense_feature_groups_) {
// only need to initialize for smaller leaf
// Get leaf boundary
const data_size_t* indices = data_partition_->indices();
data_size_t begin = data_partition_->leaf_begin(smaller_leaf);
data_size_t end = begin + data_partition_->leaf_count(smaller_leaf);
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], &indices[begin], (end-begin) * sizeof(data_size_t), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id]));
}
}
const bool ret = SerialTreeLearner::BeforeFindBestSplit(tree, left_leaf, right_leaf);
return ret;
}
bool CUDATreeLearner::ConstructGPUHistogramsAsync(
const std::vector<int8_t>& is_feature_used,
const data_size_t* data_indices, data_size_t num_data) {
if (num_data <= 0) {
return false;
}
// do nothing if no features can be processed on GPU
if (!num_dense_feature_groups_) {
Log::Debug("no dense feature groups, returning");
return false;
}
// copy data indices if it is not null
if (data_indices != nullptr && num_data != num_data_) {
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], data_indices, num_data * sizeof(data_size_t), cudaMemcpyHostToDevice, stream_[device_id]));
CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id]));
}
}
// converted indices in is_feature_used to feature-group indices
std::vector<int8_t> is_feature_group_used(num_feature_groups_, 0);
#pragma omp parallel for schedule(static, 1024) if (num_features_ >= 2048)
for (int i = 0; i < num_features_; ++i) {
if (is_feature_used[i]) {
int feature_group = train_data_->Feature2Group(i);
is_feature_group_used[feature_group] = (train_data_->FeatureGroupNumBin(feature_group) <= 16) ? 2 : 1;
}
}
// construct the feature masks for dense feature-groups
int used_dense_feature_groups = 0;
#pragma omp parallel for schedule(static, 1024) reduction(+:used_dense_feature_groups) if (num_dense_feature_groups_ >= 2048)
for (int i = 0; i < num_dense_feature_groups_; ++i) {
if (is_feature_group_used[dense_feature_group_map_[i]]) {
feature_masks_[i] = is_feature_group_used[dense_feature_group_map_[i]];
++used_dense_feature_groups;
} else {
feature_masks_[i] = 0;
}
}
bool use_all_features = ((used_dense_feature_groups == num_dense_feature_groups_) && (data_indices != nullptr));
// if no feature group is used, just return and do not use GPU
if (used_dense_feature_groups == 0) {
return false;
}
// if not all feature groups are used, we need to transfer the feature mask to GPU
// otherwise, we will use a specialized GPU kernel with all feature groups enabled
// We now copy even if all features are used.
#pragma omp parallel for schedule(static, num_gpu_)
for (int device_id = 0; device_id < num_gpu_; ++device_id) {
int offset = offset_gpu_feature_groups_[device_id];
CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_feature_masks_[device_id], ptr_pinned_feature_masks_ + offset, num_gpu_feature_groups_[device_id] , cudaMemcpyHostToDevice, stream_[device_id]));
}
// All data have been prepared, now run the GPU kernel
GPUHistogram(num_data, use_all_features);
return true;
}
void CUDATreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
std::vector<int8_t> is_sparse_feature_used(num_features_, 0);
std::vector<int8_t> is_dense_feature_used(num_features_, 0);
int num_dense_features = 0, num_sparse_features = 0;
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (!is_feature_used[feature_index]) continue;
if (train_data_->IsMultiGroup(train_data_->Feature2Group(feature_index))) {
is_sparse_feature_used[feature_index] = 1;
num_sparse_features++;
} else {
is_dense_feature_used[feature_index] = 1;
num_dense_features++;
}
}
// construct smaller leaf
hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
// Check workgroups per feature4 tuple..
int exp_workgroups_per_feature = GetNumWorkgroupsPerFeature(smaller_leaf_splits_->num_data_in_leaf());
// if the workgroup per feature is 1 (2^0), return as the work is too small for a GPU
if (exp_workgroups_per_feature == 0) {
return SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
}
// ConstructGPUHistogramsAsync will return true if there are availabe feature groups dispatched to GPU
bool is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
nullptr, smaller_leaf_splits_->num_data_in_leaf());
// then construct sparse features on CPU
// We set data_indices to null to avoid rebuilding ordered gradients/hessians
if (num_sparse_features > 0) {
train_data_->ConstructHistograms(is_sparse_feature_used,
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
share_state_.get(),
ptr_smaller_leaf_hist_data);
}
// wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) {
if (config_->gpu_use_dp) {
// use double precision
WaitAndGetHistograms<hist_t>(smaller_leaf_histogram_array_);
} else {
// use single precision
WaitAndGetHistograms<gpu_hist_t>(smaller_leaf_histogram_array_);
}
}
// Compare GPU histogram with CPU histogram, useful for debuggin GPU code problem
// #define CUDA_DEBUG_COMPARE
#ifdef CUDA_DEBUG_COMPARE
printf("Start Comparing_Histogram between GPU and CPU, num_dense_feature_groups_ = %d\n", num_dense_feature_groups_);
bool compare = true;
for (int i = 0; i < num_dense_feature_groups_; ++i) {
if (!feature_masks_[i])
continue;
int dense_feature_group_index = dense_feature_group_map_[i];
size_t size = train_data_->FeatureGroupNumBin(dense_feature_group_index);
hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
hist_t* current_histogram = ptr_smaller_leaf_hist_data + train_data_->GroupBinBoundary(dense_feature_group_index) * 2;
hist_t* gpu_histogram = new hist_t[size * 2];
data_size_t num_data = smaller_leaf_splits_->num_data_in_leaf();
printf("Comparing histogram for feature %d, num_data %d, num_data_ = %d, %lu bins\n", dense_feature_group_index, num_data, num_data_, size);
std::copy(current_histogram, current_histogram + size * 2, gpu_histogram);
std::memset(current_histogram, 0, size * sizeof(hist_t) * 2);
if (train_data_->FeatureGroupBin(dense_feature_group_index) == nullptr) {
continue;
}
if (num_data == num_data_) {
if (share_state_->is_constant_hessian) {
printf("ConstructHistogram(): num_data == num_data_ is_constant_hessian\n");
train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
0,
num_data,
gradients_,
current_histogram);
} else {
printf("ConstructHistogram(): num_data == num_data_\n");
train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
0,
num_data,
gradients_, hessians_,
current_histogram);
}
} else {
if (share_state_->is_constant_hessian) {
printf("ConstructHistogram(): is_constant_hessian\n");
train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
smaller_leaf_splits_->data_indices(),
0,
num_data,
gradients_,
current_histogram);
} else {
printf("ConstructHistogram(): 4, num_data = %d, num_data_ = %d\n", num_data, num_data_);
train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
smaller_leaf_splits_->data_indices(),
0,
num_data,
gradients_, hessians_,
current_histogram);
}
}
int retval;
if ((num_data != num_data_) && compare) {
retval = CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index, config_->gpu_use_dp, share_state_->is_constant_hessian);
printf("CompareHistograms reports %d errors\n", retval);
compare = false;
}
retval = CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index, config_->gpu_use_dp, share_state_->is_constant_hessian);
if (num_data == num_data_) {
printf("CompareHistograms reports %d errors\n", retval);
} else {
printf("CompareHistograms reports %d errors\n", retval);
}
std::copy(gpu_histogram, gpu_histogram + size * 2, current_histogram);
delete [] gpu_histogram;
}
printf("End Comparing Histogram between GPU and CPU\n");
fflush(stderr);
fflush(stdout);
#endif
if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
// construct larger leaf
hist_t* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - kHistOffset;
is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf());
// then construct sparse features on CPU
// We set data_indices to null to avoid rebuilding ordered gradients/hessians
if (num_sparse_features > 0) {
train_data_->ConstructHistograms(is_sparse_feature_used,
larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
share_state_.get(),
ptr_larger_leaf_hist_data);
}
// wait for GPU to finish, only if GPU is actually used
if (is_gpu_used) {
if (config_->gpu_use_dp) {
// use double precision
WaitAndGetHistograms<hist_t>(larger_leaf_histogram_array_);
} else {
// use single precision
WaitAndGetHistograms<gpu_hist_t>(larger_leaf_histogram_array_);
}
}
}
}
void CUDATreeLearner::FindBestSplits(const Tree* tree) {
SerialTreeLearner::FindBestSplits(tree);
#if CUDA_DEBUG >= 3
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
printf("CUDATreeLearner::FindBestSplits() Feature %d bin_size=%zd smaller leaf:\n", feature_index, bin_size);
PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->leaf_index() < 0) { continue; }
printf("CUDATreeLearner::FindBestSplits() Feature %d bin_size=%zd larger leaf:\n", feature_index, bin_size);
PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
}
#endif
}
void CUDATreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
#if CUDA_DEBUG >= 2
printf("Splitting leaf %d with feature %d thresh %d gain %f stat %f %f %f %f\n", best_Leaf, best_split_info.feature, best_split_info.threshold, best_split_info.gain, best_split_info.left_sum_gradient, best_split_info.right_sum_gradient, best_split_info.left_sum_hessian, best_split_info.right_sum_hessian);
#endif
SerialTreeLearner::Split(tree, best_Leaf, left_leaf, right_leaf);
if (Network::num_machines() == 1) {
// do some sanity check for the GPU algorithm
if (best_split_info.left_count < best_split_info.right_count) {
if ((best_split_info.left_count != smaller_leaf_splits_->num_data_in_leaf()) ||
(best_split_info.right_count!= larger_leaf_splits_->num_data_in_leaf())) {
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
}
} else {
if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) ||
(best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) {
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
}
}
}
}
} // namespace LightGBM
#undef cudaMemcpy_DEBUG
#endif // USE_CUDA
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/dataset.h>
#include <LightGBM/feature_group.h>
#include <LightGBM/tree.h>
#include <string>
#include <cmath>
#include <cstdio>
#include <memory>
#include <random>
#include <vector>
#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif
#include "feature_histogram.hpp"
#include "serial_tree_learner.h"
#include "data_partition.hpp"
#include "split_info.hpp"
#include "leaf_splits.hpp"
#ifdef USE_CUDA
#include <LightGBM/cuda/vector_cudahost.h>
#include "cuda_kernel_launcher.h"
using json11::Json;
namespace LightGBM {
/*!
* \brief CUDA-based parallel learning algorithm.
*/
class CUDATreeLearner: public SerialTreeLearner {
public:
explicit CUDATreeLearner(const Config* tree_config);
~CUDATreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override;
Tree* Train(const score_t* gradients, const score_t *hessians);
void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(subset, used_indices, num_data);
if (subset == nullptr && used_indices != nullptr) {
if (num_data != num_data_) {
use_bagging_ = true;
return;
}
}
use_bagging_ = false;
}
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestSplits(const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private:
typedef float gpu_hist_t;
/*!
* \brief Find the best number of workgroups processing one feature for maximizing efficiency
* \param leaf_num_data The number of data examples on the current leaf being processed
* \return Log2 of the best number for workgroups per feature, in range 0...kMaxLogWorkgroupsPerFeature
*/
int GetNumWorkgroupsPerFeature(data_size_t leaf_num_data);
/*!
* \brief Initialize GPU device
* \param num_gpu: number of maximum gpus
*/
void InitGPU(int num_gpu);
/*!
* \brief Allocate memory for GPU computation // alloc only
*/
void CountDenseFeatureGroups(); // compute num_dense_feature_group
void prevAllocateGPUMemory(); // compute CPU-side param calculation & Pin HostMemory
void AllocateGPUMemory();
/*!
* \ ResetGPUMemory
*/
void ResetGPUMemory();
/*!
* \ copy dense feature from CPU to GPU
*/
void copyDenseFeature();
/*!
* \brief Compute GPU feature histogram for the current leaf.
* Indices, gradients and hessians have been copied to the device.
* \param leaf_num_data Number of data on current leaf
* \param use_all_features Set to true to not use feature masks, with a faster kernel
*/
void GPUHistogram(data_size_t leaf_num_data, bool use_all_features);
void SetThreadData(ThreadData* thread_data, int device_id, int histogram_size,
int leaf_num_data, bool use_all_features,
int num_workgroups, int exp_workgroups_per_feature) {
ThreadData* td = &thread_data[device_id];
td->device_id = device_id;
td->histogram_size = histogram_size;
td->leaf_num_data = leaf_num_data;
td->num_data = num_data_;
td->use_all_features = use_all_features;
td->is_constant_hessian = share_state_->is_constant_hessian;
td->num_workgroups = num_workgroups;
td->stream = stream_[device_id];
td->device_features = device_features_[device_id];
td->device_feature_masks = reinterpret_cast<uint8_t *>(device_feature_masks_[device_id]);
td->device_data_indices = device_data_indices_[device_id];
td->device_gradients = device_gradients_[device_id];
td->device_hessians = device_hessians_[device_id];
td->hessians_const = hessians_[0];
td->device_subhistograms = device_subhistograms_[device_id];
td->sync_counters = sync_counters_[device_id];
td->device_histogram_outputs = device_histogram_outputs_[device_id];
td->exp_workgroups_per_feature = exp_workgroups_per_feature;
td->kernel_start = &(kernel_start_[device_id]);
td->kernel_wait_obj = &(kernel_wait_obj_[device_id]);
td->kernel_input_wait_time = &(kernel_input_wait_time_[device_id]);
size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
td->output_size = output_size;
td->host_histogram_output = reinterpret_cast<char*>(host_histogram_outputs_) + host_output_offset;
td->histograms_wait_obj = &(histograms_wait_obj_[device_id]);
}
/*!
* \brief Wait for GPU kernel execution and read histogram
* \param histograms Destination of histogram results from GPU.
*/
template <typename HistType>
void WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array);
/*!
* \brief Construct GPU histogram asynchronously.
* Interface is similar to Dataset::ConstructHistograms().
* \param is_feature_used A predicate vector for enabling each feature
* \param data_indices Array of data example IDs to be included in histogram, will be copied to GPU.
* Set to nullptr to skip copy to GPU.
* \param num_data Number of data examples to be included in histogram
* \return true if GPU kernel is launched, false if GPU is not used
*/
bool ConstructGPUHistogramsAsync(
const std::vector<int8_t>& is_feature_used,
const data_size_t* data_indices, data_size_t num_data);
/*! brief Log2 of max number of workgroups per feature*/
const int kMaxLogWorkgroupsPerFeature = 10; // 2^10
/*! brief Max total number of workgroups with preallocated workspace.
* If we use more than this number of workgroups, we have to reallocate subhistograms */
std::vector<int> preallocd_max_num_wg_;
/*! \brief True if bagging is used */
bool use_bagging_;
/*! \brief GPU command queue object */
std::vector<cudaStream_t> stream_;
/*! \brief total number of feature-groups */
int num_feature_groups_;
/*! \brief total number of dense feature-groups, which will be processed on GPU */
int num_dense_feature_groups_;
std::vector<int> num_gpu_feature_groups_;
std::vector<int> offset_gpu_feature_groups_;
/*! \brief On GPU we read one DWORD (4-byte) of features of one example once.
* With bin size > 16, there are 4 features per DWORD.
* With bin size <=16, there are 8 features per DWORD.
*/
int dword_features_;
/*! \brief Max number of bins of training data, used to determine
* which GPU kernel to use */
int max_num_bin_;
/*! \brief Used GPU kernel bin size (64, 256) */
int histogram_size_;
int device_bin_size_;
/*! \brief Size of histogram bin entry, depending if single or double precision is used */
size_t hist_bin_entry_sz_;
/*! \brief Indices of all dense feature-groups */
std::vector<int> dense_feature_group_map_;
/*! \brief Indices of all sparse feature-groups */
std::vector<int> sparse_feature_group_map_;
/*! \brief GPU memory object holding the training data */
std::vector<uint8_t*> device_features_;
/*! \brief GPU memory object holding the ordered gradient */
std::vector<score_t*> device_gradients_;
/*! \brief Pointer to pinned memory of ordered gradient */
void * ptr_pinned_gradients_ = nullptr;
/*! \brief GPU memory object holding the ordered hessian */
std::vector<score_t*> device_hessians_;
/*! \brief Pointer to pinned memory of ordered hessian */
void * ptr_pinned_hessians_ = nullptr;
/*! \brief A vector of feature mask. 1 = feature used, 0 = feature not used */
std::vector<char> feature_masks_;
/*! \brief GPU memory object holding the feature masks */
std::vector<char*> device_feature_masks_;
/*! \brief Pointer to pinned memory of feature masks */
char* ptr_pinned_feature_masks_ = nullptr;
/*! \brief GPU memory object holding indices of the leaf being processed */
std::vector<data_size_t*> device_data_indices_;
/*! \brief GPU memory object holding counters for workgroup coordination */
std::vector<int*> sync_counters_;
/*! \brief GPU memory object holding temporary sub-histograms per workgroup */
std::vector<char*> device_subhistograms_;
/*! \brief Host memory object for histogram output (GPU will write to Host memory directly) */
std::vector<void*> device_histogram_outputs_;
/*! \brief Host memory pointer for histogram outputs */
void *host_histogram_outputs_;
/*! CUDA waitlist object for waiting for data transfer before kernel execution */
std::vector<cudaEvent_t> kernel_wait_obj_;
/*! CUDA waitlist object for reading output histograms after kernel execution */
std::vector<cudaEvent_t> histograms_wait_obj_;
/*! CUDA Asynchronous waiting object for copying indices */
std::vector<cudaEvent_t> indices_future_;
/*! Asynchronous waiting object for copying gradients */
std::vector<cudaEvent_t> gradients_future_;
/*! Asynchronous waiting object for copying hessians */
std::vector<cudaEvent_t> hessians_future_;
/*! Asynchronous waiting object for copying dense features */
std::vector<cudaEvent_t> features_future_;
// host-side buffer for converting feature data into featre4 data
int nthreads_; // number of Feature4* vector on host4_vecs_
std::vector<cudaEvent_t> kernel_start_;
std::vector<float> kernel_time_; // measure histogram kernel time
std::vector<std::chrono::duration<double, std::milli>> kernel_input_wait_time_;
int num_gpu_;
int allocated_num_data_; // allocated data instances
pthread_t **cpu_threads_; // pthread, 1 cpu thread / gpu
};
} // namespace LightGBM
#else // USE_CUDA
// When GPU support is not compiled in, quit with an error message
namespace LightGBM {
class CUDATreeLearner: public SerialTreeLearner {
public:
#pragma warning(disable : 4702)
explicit CUDATreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("CUDA Tree Learner was not enabled in this build.\n"
"Please recompile with CMake option -DUSE_CUDA=1");
}
};
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
......@@ -256,6 +256,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, in
}
// instantiate template classes, otherwise linker cannot find the code
template class DataParallelTreeLearner<CUDATreeLearner>;
template class DataParallelTreeLearner<GPUTreeLearner>;
template class DataParallelTreeLearner<SerialTreeLearner>;
......
......@@ -77,6 +77,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
}
// instantiate template classes, otherwise linker cannot find the code
template class FeatureParallelTreeLearner<CUDATreeLearner>;
template class FeatureParallelTreeLearner<GPUTreeLearner>;
template class FeatureParallelTreeLearner<SerialTreeLearner>;
} // namespace LightGBM
......@@ -52,7 +52,7 @@ void PrintHistograms(hist_t* h, size_t size) {
double total_hess = 0;
for (size_t i = 0; i < size; ++i) {
printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
if ((i & 2) == 2)
if ((i & 3) == 3)
printf("\n");
total_hess += GET_HESS(h, i);
}
......@@ -1068,10 +1068,10 @@ void GPUTreeLearner::FindBestSplits(const Tree* tree) {
}
size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
printf("Feature %d smaller leaf:\n", feature_index);
PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
printf("Feature %d larger leaf:\n", feature_index);
PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
}
#endif
}
......
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include <LightGBM/meta.h>
#include <cstdint>
#include <cstdio>
#include "histogram_16_64_256.hu"
namespace LightGBM {
// atomic add for float number in local memory
inline __device__ void atomic_local_add_f(acc_type *addr, const acc_type val) {
atomicAdd(addr, static_cast<acc_type>(val));
}
// histogram16 stuff
#ifdef ENABLE_ALL_FEATURES
#ifdef IGNORE_INDICES
#define KERNEL_NAME histogram16_fulldata
#else // IGNORE_INDICES
#define KERNEL_NAME histogram16
#endif // IGNORE_INDICES
#else // ENABLE_ALL_FEATURES
#error "ENABLE_ALL_FEATURES should always be 1"
#define KERNEL_NAME histogram16
#endif // ENABLE_ALL_FEATURES
#define NUM_BINS 16
#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS)
// this function will be called by histogram16
// we have one sub-histogram of one feature in local memory, and need to read others
inline void __device__ within_kernel_reduction16x4(const acc_type* __restrict__ feature_sub_hist,
const unsigned int skip_id,
const unsigned int old_val_cont_bin0,
const uint16_t num_sub_hist,
acc_type* __restrict__ output_buf,
acc_type* __restrict__ local_hist,
const size_t power_feature_workgroups) {
const uint16_t ltid = threadIdx.x;
acc_type grad_bin = local_hist[ltid * 2];
acc_type hess_bin = local_hist[ltid * 2 + 1];
unsigned int* __restrict__ local_cnt = reinterpret_cast<unsigned int *>(local_hist + 2 * NUM_BINS);
unsigned int cont_bin;
if (power_feature_workgroups != 0) {
cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0;
} else {
cont_bin = local_cnt[ltid];
}
uint16_t i;
if (power_feature_workgroups != 0) {
// add all sub-histograms for feature
const acc_type* __restrict__ p = feature_sub_hist + ltid;
for (i = 0; i < skip_id; ++i) {
grad_bin += *p; p += NUM_BINS;
hess_bin += *p; p += NUM_BINS;
cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
// skip the counters we already have
p += 3 * NUM_BINS;
for (i = i + 1; i < num_sub_hist; ++i) {
grad_bin += *p; p += NUM_BINS;
hess_bin += *p; p += NUM_BINS;
cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
}
__syncthreads();
output_buf[ltid * 2 + 0] = grad_bin;
output_buf[ltid * 2 + 1] = hess_bin;
}
#if USE_CONSTANT_BUF == 1
__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base,
__constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
const data_size_t num_data,
__constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
#if CONST_HESSIAN == 0
__constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
#else
const score_t const_hessian,
#endif
char* __restrict__ output_buf,
volatile int * sync_counters,
acc_type* __restrict__ hist_buf_base,
const size_t power_feature_workgroups) {
#else
__global__ void KERNEL_NAME(const uchar* feature_data_base,
const uchar* __restrict__ feature_masks,
const data_size_t feature_size,
const data_size_t* data_indices,
const data_size_t num_data,
const score_t* ordered_gradients,
#if CONST_HESSIAN == 0
const score_t* ordered_hessians,
#else
const score_t const_hessian,
#endif
char* __restrict__ output_buf,
volatile int * sync_counters,
acc_type* __restrict__ hist_buf_base,
const size_t power_feature_workgroups) {
#endif
// allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
// otherwise a "Misaligned Address" exception may occur
__shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x;
const uint16_t ltid = threadIdx.x;
const uint16_t lsize = NUM_BINS; // get_local_size(0);
const uint16_t group_id = blockIdx.x;
// local memory per workgroup is 3 KB
// clear local memory
unsigned int *ptr = reinterpret_cast<unsigned int *>(shared_array);
for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) {
ptr[i] = 0;
}
__syncthreads();
// gradient/hessian histograms
// assume this starts at 32 * 4 = 128-byte boundary // What does it mean? boundary??
// total size: 2 * 256 * size_of(float) = 2 KB
// organization: each feature/grad/hessian is at a different bank,
// as indepedent of the feature value as possible
acc_type *gh_hist = reinterpret_cast<acc_type *>(shared_array);
// counter histogram
// total size: 256 * size_of(unsigned int) = 1 KB
unsigned int *cnt_hist = reinterpret_cast<unsigned int *>(gh_hist + 2 * NUM_BINS);
// odd threads (1, 3, ...) compute histograms for hessians first
// even thread (0, 2, ...) compute histograms for gradients first
// etc.
uchar is_hessian_first = ltid & 1;
uint16_t feature_id = group_id >> power_feature_workgroups;
// each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
// feature_size is the number of examples per feature
const uchar *feature_data = feature_data_base + feature_id * feature_size;
// size of threads that process this feature4
const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups);
// equavalent thread ID in this subgroup for this feature4
const unsigned int subglobal_tid = gtid - feature_id * subglobal_size;
data_size_t ind;
data_size_t ind_next;
#ifdef IGNORE_INDICES
ind = subglobal_tid;
#else
ind = data_indices[subglobal_tid];
#endif
// extract feature mask, when a byte is set to 0, that feature is disabled
uchar feature_mask = feature_masks[feature_id];
// exit if the feature is masked
if (!feature_mask) {
return;
} else {
feature_mask = feature_mask - 1; // feature_mask is used for get feature (1: 4bit feature, 0: 8bit feature)
}
// STAGE 1: read feature data, and gradient and hessian
// first half of the threads read feature data from global memory
// We will prefetch data into the "next" variable at the beginning of each iteration
uchar feature;
uchar feature_next;
uint16_t bin;
feature = feature_data[ind >> feature_mask];
if (feature_mask) {
feature = (feature >> ((ind & 1) << 2)) & 0xf;
}
bin = feature;
acc_type grad_bin = 0.0f, hess_bin = 0.0f;
acc_type *addr_bin;
// store gradient and hessian
score_t grad, hess;
score_t grad_next, hess_next;
grad = ordered_gradients[ind];
#if CONST_HESSIAN == 0
hess = ordered_hessians[ind];
#endif
// there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) {
// prefetch the next iteration variables
// we don't need bondary check because we have made the buffer large
int i_next = i + subglobal_size;
#ifdef IGNORE_INDICES
// we need to check to bounds here
ind_next = i_next < num_data ? i_next : i;
#else
ind_next = data_indices[i_next];
#endif
grad_next = ordered_gradients[ind_next];
#if CONST_HESSIAN == 0
hess_next = ordered_hessians[ind_next];
#endif
// STAGE 2: accumulate gradient and hessian
if (bin != feature) {
addr_bin = gh_hist + bin * 2 + is_hessian_first;
#if CONST_HESSIAN == 0
acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
atomic_local_add_f(addr_bin, acc_bin);
addr_bin = addr_bin + 1 - 2 * is_hessian_first;
acc_bin = is_hessian_first ? grad_bin : hess_bin;
atomic_local_add_f(addr_bin, acc_bin);
#elif CONST_HESSIAN == 1
atomic_local_add_f(addr_bin, grad_bin);
#endif
bin = feature;
grad_bin = grad;
hess_bin = hess;
} else {
grad_bin += grad;
hess_bin += hess;
}
// prefetch the next iteration variables
feature_next = feature_data[ind_next >> feature_mask];
// STAGE 3: accumulate counter
atomicAdd(cnt_hist + feature, 1);
// STAGE 4: update next stat
grad = grad_next;
hess = hess_next;
if (!feature_mask) {
feature = feature_next;
} else {
feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf;
}
}
addr_bin = gh_hist + bin * 2 + is_hessian_first;
#if CONST_HESSIAN == 0
acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
atomic_local_add_f(addr_bin, acc_bin);
addr_bin = addr_bin + 1 - 2 * is_hessian_first;
acc_bin = is_hessian_first ? grad_bin : hess_bin;
atomic_local_add_f(addr_bin, acc_bin);
#elif CONST_HESSIAN == 1
atomic_local_add_f(addr_bin, grad_bin);
#endif
__syncthreads();
#if CONST_HESSIAN == 1
// make a final reduction
gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1];
gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position
__syncthreads();
#endif
#if POWER_FEATURE_WORKGROUPS != 0
acc_type *__restrict__ output = reinterpret_cast<acc_type *>(output_buf) + group_id * 3 * NUM_BINS;
// write gradients and hessians
acc_type *__restrict__ ptr_f = output;
for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) {
// even threads read gradients, odd threads read hessians
acc_type value = gh_hist[i];
ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
}
// write counts
acc_int_type *__restrict__ ptr_i = reinterpret_cast<acc_int_type *>(output + 2 * NUM_BINS);
for (uint16_t i = ltid; i < NUM_BINS; i += lsize) {
unsigned int value = cnt_hist[i];
ptr_i[i] = value;
}
__syncthreads();
__threadfence();
unsigned int * counter_val = cnt_hist;
// backup the old value
unsigned int old_val = *counter_val;
if (ltid == 0) {
// all workgroups processing the same feature add this counter
*counter_val = atomicAdd(const_cast<int*>(sync_counters + feature_id), 1);
}
// make sure everyone in this workgroup is here
__syncthreads();
// everyone in this workgroup: if we are the last workgroup, then do reduction!
if (*counter_val == (1 << power_feature_workgroups) - 1) {
if (ltid == 0) {
sync_counters[feature_id] = 0;
}
#else
}
// only 1 work group, no need to increase counter
// the reduction will become a simple copy
{
unsigned int old_val; // dummy
#endif
// locate our feature's block in output memory
unsigned int output_offset = (feature_id << power_feature_workgroups);
acc_type const * __restrict__ feature_subhists =
reinterpret_cast<acc_type *>(output_buf) + output_offset * 3 * NUM_BINS;
// skip reading the data already in local memory
unsigned int skip_id = group_id - output_offset;
// locate output histogram location for this feature4
acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS;
within_kernel_reduction16x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast<acc_type *>(shared_array), power_feature_workgroups);
}
}
// end of histogram16 stuff
// histogram64 stuff
#undef KERNEL_NAME
#undef NUM_BINS
#undef LOCAL_MEM_SIZE
#ifdef ENABLE_ALL_FEATURES
#ifdef IGNORE_INDICES
#define KERNEL_NAME histogram64_fulldata
#else // IGNORE_INDICES
#define KERNEL_NAME histogram64 // seems like ENABLE_ALL_FEATURES is set to 1 in the header if its disabled
// #define KERNEL_NAME histogram64_allfeats
#endif // IGNORE_INDICES
#else // ENABLE_ALL_FEATURES
#error "ENABLE_ALL_FEATURES should always be 1"
#define KERNEL_NAME histogram64
#endif // ENABLE_ALL_FEATURES
#define NUM_BINS 64
#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS)
// this function will be called by histogram64
// we have one sub-histogram of one feature in local memory, and need to read others
inline void __device__ within_kernel_reduction64x4(const acc_type* __restrict__ feature_sub_hist,
const unsigned int skip_id,
const unsigned int old_val_cont_bin0,
const uint16_t num_sub_hist,
acc_type* __restrict__ output_buf,
acc_type* __restrict__ local_hist,
const size_t power_feature_workgroups) {
const uint16_t ltid = threadIdx.x;
acc_type grad_bin = local_hist[ltid * 2];
acc_type hess_bin = local_hist[ltid * 2 + 1];
unsigned int* __restrict__ local_cnt = reinterpret_cast<unsigned int *>(local_hist + 2 * NUM_BINS);
unsigned int cont_bin;
if (power_feature_workgroups != 0) {
cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0;
} else {
cont_bin = local_cnt[ltid];
}
uint16_t i;
if (power_feature_workgroups != 0) {
// add all sub-histograms for feature
const acc_type* __restrict__ p = feature_sub_hist + ltid;
for (i = 0; i < skip_id; ++i) {
grad_bin += *p; p += NUM_BINS;
hess_bin += *p; p += NUM_BINS;
cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
// skip the counters we already have
p += 3 * NUM_BINS;
for (i = i + 1; i < num_sub_hist; ++i) {
grad_bin += *p; p += NUM_BINS;
hess_bin += *p; p += NUM_BINS;
cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
}
__syncthreads();
output_buf[ltid * 2 + 0] = grad_bin;
output_buf[ltid * 2 + 1] = hess_bin;
}
#if USE_CONSTANT_BUF == 1
__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base,
__constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
const data_size_t num_data,
__constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
#if CONST_HESSIAN == 0
__constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
#else
const score_t const_hessian,
#endif
char* __restrict__ output_buf,
volatile int * sync_counters,
acc_type* __restrict__ hist_buf_base,
const size_t power_feature_workgroups) {
#else
__global__ void KERNEL_NAME(const uchar* feature_data_base,
const uchar* __restrict__ feature_masks,
const data_size_t feature_size,
const data_size_t* data_indices,
const data_size_t num_data,
const score_t* ordered_gradients,
#if CONST_HESSIAN == 0
const score_t* ordered_hessians,
#else
const score_t const_hessian,
#endif
char* __restrict__ output_buf,
volatile int * sync_counters,
acc_type* __restrict__ hist_buf_base,
const size_t power_feature_workgroups) {
#endif
// allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
// otherwise a "Misaligned Address" exception may occur
__shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x;
const uint16_t ltid = threadIdx.x;
const uint16_t lsize = NUM_BINS; // get_local_size(0);
const uint16_t group_id = blockIdx.x;
// local memory per workgroup is 3 KB
// clear local memory
unsigned int *ptr = reinterpret_cast<unsigned int *>(shared_array);
for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) {
ptr[i] = 0;
}
__syncthreads();
// gradient/hessian histograms
// assume this starts at 32 * 4 = 128-byte boundary // What does it mean? boundary??
// total size: 2 * 256 * size_of(float) = 2 KB
// organization: each feature/grad/hessian is at a different bank,
// as indepedent of the feature value as possible
acc_type *gh_hist = reinterpret_cast<acc_type *>(shared_array);
// counter histogram
// total size: 256 * size_of(unsigned int) = 1 KB
unsigned int *cnt_hist = reinterpret_cast<unsigned int *>(gh_hist + 2 * NUM_BINS);
// odd threads (1, 3, ...) compute histograms for hessians first
// even thread (0, 2, ...) compute histograms for gradients first
// etc.
uchar is_hessian_first = ltid & 1;
uint16_t feature_id = group_id >> power_feature_workgroups;
// each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
// feature_size is the number of examples per feature
const uchar *feature_data = feature_data_base + feature_id * feature_size;
// size of threads that process this feature4
const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups);
// equavalent thread ID in this subgroup for this feature4
const unsigned int subglobal_tid = gtid - feature_id * subglobal_size;
data_size_t ind;
data_size_t ind_next;
#ifdef IGNORE_INDICES
ind = subglobal_tid;
#else
ind = data_indices[subglobal_tid];
#endif
// extract feature mask, when a byte is set to 0, that feature is disabled
uchar feature_mask = feature_masks[feature_id];
// exit if the feature is masked
if (!feature_mask) {
return;
} else {
feature_mask = feature_mask - 1; // feature_mask is used for get feature (1: 4bit feature, 0: 8bit feature)
}
// STAGE 1: read feature data, and gradient and hessian
// first half of the threads read feature data from global memory
// We will prefetch data into the "next" variable at the beginning of each iteration
uchar feature;
uchar feature_next;
uint16_t bin;
feature = feature_data[ind >> feature_mask];
if (feature_mask) {
feature = (feature >> ((ind & 1) << 2)) & 0xf;
}
bin = feature;
acc_type grad_bin = 0.0f, hess_bin = 0.0f;
acc_type *addr_bin;
// store gradient and hessian
score_t grad, hess;
score_t grad_next, hess_next;
grad = ordered_gradients[ind];
#if CONST_HESSIAN == 0
hess = ordered_hessians[ind];
#endif
// there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) {
// prefetch the next iteration variables
// we don't need bondary check because we have made the buffer large
int i_next = i + subglobal_size;
#ifdef IGNORE_INDICES
// we need to check to bounds here
ind_next = i_next < num_data ? i_next : i;
#else
ind_next = data_indices[i_next];
#endif
grad_next = ordered_gradients[ind_next];
#if CONST_HESSIAN == 0
hess_next = ordered_hessians[ind_next];
#endif
// STAGE 2: accumulate gradient and hessian
if (bin != feature) {
addr_bin = gh_hist + bin * 2 + is_hessian_first;
#if CONST_HESSIAN == 0
acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
atomic_local_add_f(addr_bin, acc_bin);
addr_bin = addr_bin + 1 - 2 * is_hessian_first;
acc_bin = is_hessian_first ? grad_bin : hess_bin;
atomic_local_add_f(addr_bin, acc_bin);
#elif CONST_HESSIAN == 1
atomic_local_add_f(addr_bin, grad_bin);
#endif
bin = feature;
grad_bin = grad;
hess_bin = hess;
} else {
grad_bin += grad;
hess_bin += hess;
}
// prefetch the next iteration variables
feature_next = feature_data[ind_next >> feature_mask];
// STAGE 3: accumulate counter
atomicAdd(cnt_hist + feature, 1);
// STAGE 4: update next stat
grad = grad_next;
hess = hess_next;
if (!feature_mask) {
feature = feature_next;
} else {
feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf;
}
}
addr_bin = gh_hist + bin * 2 + is_hessian_first;
#if CONST_HESSIAN == 0
acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
atomic_local_add_f(addr_bin, acc_bin);
addr_bin = addr_bin + 1 - 2 * is_hessian_first;
acc_bin = is_hessian_first ? grad_bin : hess_bin;
atomic_local_add_f(addr_bin, acc_bin);
#elif CONST_HESSIAN == 1
atomic_local_add_f(addr_bin, grad_bin);
#endif
__syncthreads();
#if CONST_HESSIAN == 1
// make a final reduction
gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1];
gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position
__syncthreads();
#endif
#if POWER_FEATURE_WORKGROUPS != 0
acc_type *__restrict__ output = reinterpret_cast<acc_type *>(output_buf) + group_id * 3 * NUM_BINS;
// write gradients and hessians
acc_type *__restrict__ ptr_f = output;
for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) {
// even threads read gradients, odd threads read hessians
acc_type value = gh_hist[i];
ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
}
// write counts
acc_int_type *__restrict__ ptr_i = reinterpret_cast<acc_int_type *>(output + 2 * NUM_BINS);
for (uint16_t i = ltid; i < NUM_BINS; i += lsize) {
unsigned int value = cnt_hist[i];
ptr_i[i] = value;
}
__syncthreads();
__threadfence();
unsigned int * counter_val = cnt_hist;
// backup the old value
unsigned int old_val = *counter_val;
if (ltid == 0) {
// all workgroups processing the same feature add this counter
*counter_val = atomicAdd(const_cast<int*>(sync_counters + feature_id), 1);
}
// make sure everyone in this workgroup is here
__syncthreads();
// everyone in this workgroup: if we are the last workgroup, then do reduction!
if (*counter_val == (1 << power_feature_workgroups) - 1) {
if (ltid == 0) {
sync_counters[feature_id] = 0;
}
#else
}
// only 1 work group, no need to increase counter
// the reduction will become a simple copy
{
unsigned int old_val; // dummy
#endif
// locate our feature's block in output memory
unsigned int output_offset = (feature_id << power_feature_workgroups);
acc_type const * __restrict__ feature_subhists =
reinterpret_cast<acc_type *>(output_buf) + output_offset * 3 * NUM_BINS;
// skip reading the data already in local memory
unsigned int skip_id = group_id - output_offset;
// locate output histogram location for this feature4
acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS;
within_kernel_reduction64x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast<acc_type *>(shared_array), power_feature_workgroups);
}
}
// end of histogram64 stuff
// histogram256 stuff
#undef KERNEL_NAME
#undef NUM_BINS
#undef LOCAL_MEM_SIZE
#ifdef ENABLE_ALL_FEATURES
#ifdef IGNORE_INDICES
#define KERNEL_NAME histogram256_fulldata
#else // IGNORE_INDICES
#define KERNEL_NAME histogram256 // seems like ENABLE_ALL_FEATURES is set to 1 in the header if its disabled
// #define KERNEL_NAME histogram256_allfeats
#endif // IGNORE_INDICES
#else // ENABLE_ALL_FEATURES
#error "ENABLE_ALL_FEATURES should always be 1"
#define KERNEL_NAME histogram256
#endif // ENABLE_ALL_FEATURES
#define NUM_BINS 256
#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS)
// this function will be called by histogram256
// we have one sub-histogram of one feature in local memory, and need to read others
inline void __device__ within_kernel_reduction256x4(const acc_type* __restrict__ feature_sub_hist,
const unsigned int skip_id,
const unsigned int old_val_cont_bin0,
const uint16_t num_sub_hist,
acc_type* __restrict__ output_buf,
acc_type* __restrict__ local_hist,
const size_t power_feature_workgroups) {
const uint16_t ltid = threadIdx.x;
acc_type grad_bin = local_hist[ltid * 2];
acc_type hess_bin = local_hist[ltid * 2 + 1];
unsigned int* __restrict__ local_cnt = reinterpret_cast<unsigned int *>(local_hist + 2 * NUM_BINS);
unsigned int cont_bin;
if (power_feature_workgroups != 0) {
cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0;
} else {
cont_bin = local_cnt[ltid];
}
uint16_t i;
if (power_feature_workgroups != 0) {
// add all sub-histograms for feature
const acc_type* __restrict__ p = feature_sub_hist + ltid;
for (i = 0; i < skip_id; ++i) {
grad_bin += *p; p += NUM_BINS;
hess_bin += *p; p += NUM_BINS;
cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
// skip the counters we already have
p += 3 * NUM_BINS;
for (i = i + 1; i < num_sub_hist; ++i) {
grad_bin += *p; p += NUM_BINS;
hess_bin += *p; p += NUM_BINS;
cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
}
__syncthreads();
output_buf[ltid * 2 + 0] = grad_bin;
output_buf[ltid * 2 + 1] = hess_bin;
}
#if USE_CONSTANT_BUF == 1
__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base,
__constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
const data_size_t num_data,
__constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
#if CONST_HESSIAN == 0
__constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
#else
const score_t const_hessian,
#endif
char* __restrict__ output_buf,
volatile int * sync_counters,
acc_type* __restrict__ hist_buf_base,
const size_t power_feature_workgroups) {
#else
__global__ void KERNEL_NAME(const uchar* feature_data_base,
const uchar* __restrict__ feature_masks,
const data_size_t feature_size,
const data_size_t* data_indices,
const data_size_t num_data,
const score_t* ordered_gradients,
#if CONST_HESSIAN == 0
const score_t* ordered_hessians,
#else
const score_t const_hessian,
#endif
char* __restrict__ output_buf,
volatile int * sync_counters,
acc_type* __restrict__ hist_buf_base,
const size_t power_feature_workgroups) {
#endif
// allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
// otherwise a "Misaligned Address" exception may occur
__shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x;
const uint16_t ltid = threadIdx.x;
const uint16_t lsize = NUM_BINS; // get_local_size(0);
const uint16_t group_id = blockIdx.x;
// local memory per workgroup is 3 KB
// clear local memory
unsigned int *ptr = reinterpret_cast<unsigned int *>(shared_array);
for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) {
ptr[i] = 0;
}
__syncthreads();
// gradient/hessian histograms
// assume this starts at 32 * 4 = 128-byte boundary // What does it mean? boundary??
// total size: 2 * 256 * size_of(float) = 2 KB
// organization: each feature/grad/hessian is at a different bank,
// as indepedent of the feature value as possible
acc_type *gh_hist = reinterpret_cast<acc_type *>(shared_array);
// counter histogram
// total size: 256 * size_of(unsigned int) = 1 KB
unsigned int *cnt_hist = reinterpret_cast<unsigned int *>(gh_hist + 2 * NUM_BINS);
// odd threads (1, 3, ...) compute histograms for hessians first
// even thread (0, 2, ...) compute histograms for gradients first
// etc.
uchar is_hessian_first = ltid & 1;
uint16_t feature_id = group_id >> power_feature_workgroups;
// each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
// feature_size is the number of examples per feature
const uchar *feature_data = feature_data_base + feature_id * feature_size;
// size of threads that process this feature4
const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups);
// equavalent thread ID in this subgroup for this feature4
const unsigned int subglobal_tid = gtid - feature_id * subglobal_size;
data_size_t ind;
data_size_t ind_next;
#ifdef IGNORE_INDICES
ind = subglobal_tid;
#else
ind = data_indices[subglobal_tid];
#endif
// extract feature mask, when a byte is set to 0, that feature is disabled
uchar feature_mask = feature_masks[feature_id];
// exit if the feature is masked
if (!feature_mask) {
return;
} else {
feature_mask = feature_mask - 1; // feature_mask is used for get feature (1: 4bit feature, 0: 8bit feature)
}
// STAGE 1: read feature data, and gradient and hessian
// first half of the threads read feature data from global memory
// We will prefetch data into the "next" variable at the beginning of each iteration
uchar feature;
uchar feature_next;
uint16_t bin;
feature = feature_data[ind >> feature_mask];
if (feature_mask) {
feature = (feature >> ((ind & 1) << 2)) & 0xf;
}
bin = feature;
acc_type grad_bin = 0.0f, hess_bin = 0.0f;
acc_type *addr_bin;
// store gradient and hessian
score_t grad, hess;
score_t grad_next, hess_next;
grad = ordered_gradients[ind];
#if CONST_HESSIAN == 0
hess = ordered_hessians[ind];
#endif
// there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) {
// prefetch the next iteration variables
// we don't need bondary check because we have made the buffer large
int i_next = i + subglobal_size;
#ifdef IGNORE_INDICES
// we need to check to bounds here
ind_next = i_next < num_data ? i_next : i;
#else
ind_next = data_indices[i_next];
#endif
grad_next = ordered_gradients[ind_next];
#if CONST_HESSIAN == 0
hess_next = ordered_hessians[ind_next];
#endif
// STAGE 2: accumulate gradient and hessian
if (bin != feature) {
addr_bin = gh_hist + bin * 2 + is_hessian_first;
#if CONST_HESSIAN == 0
acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
atomic_local_add_f(addr_bin, acc_bin);
addr_bin = addr_bin + 1 - 2 * is_hessian_first;
acc_bin = is_hessian_first ? grad_bin : hess_bin;
atomic_local_add_f(addr_bin, acc_bin);
#elif CONST_HESSIAN == 1
atomic_local_add_f(addr_bin, grad_bin);
#endif
bin = feature;
grad_bin = grad;
hess_bin = hess;
} else {
grad_bin += grad;
hess_bin += hess;
}
// prefetch the next iteration variables
feature_next = feature_data[ind_next >> feature_mask];
// STAGE 3: accumulate counter
atomicAdd(cnt_hist + feature, 1);
// STAGE 4: update next stat
grad = grad_next;
hess = hess_next;
if (!feature_mask) {
feature = feature_next;
} else {
feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf;
}
}
addr_bin = gh_hist + bin * 2 + is_hessian_first;
#if CONST_HESSIAN == 0
acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin;
atomic_local_add_f(addr_bin, acc_bin);
addr_bin = addr_bin + 1 - 2 * is_hessian_first;
acc_bin = is_hessian_first ? grad_bin : hess_bin;
atomic_local_add_f(addr_bin, acc_bin);
#elif CONST_HESSIAN == 1
atomic_local_add_f(addr_bin, grad_bin);
#endif
__syncthreads();
#if CONST_HESSIAN == 1
// make a final reduction
gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1];
gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position
__syncthreads();
#endif
#if POWER_FEATURE_WORKGROUPS != 0
acc_type *__restrict__ output = reinterpret_cast<acc_type *>(output_buf) + group_id * 3 * NUM_BINS;
// write gradients and hessians
acc_type *__restrict__ ptr_f = output;
for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) {
// even threads read gradients, odd threads read hessians
acc_type value = gh_hist[i];
ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
}
// write counts
acc_int_type *__restrict__ ptr_i = reinterpret_cast<acc_int_type *>(output + 2 * NUM_BINS);
for (uint16_t i = ltid; i < NUM_BINS; i += lsize) {
unsigned int value = cnt_hist[i];
ptr_i[i] = value;
}
__syncthreads();
__threadfence();
unsigned int * counter_val = cnt_hist;
// backup the old value
unsigned int old_val = *counter_val;
if (ltid == 0) {
// all workgroups processing the same feature add this counter
*counter_val = atomicAdd(const_cast<int*>(sync_counters + feature_id), 1);
}
// make sure everyone in this workgroup is here
__syncthreads();
// everyone in this workgroup: if we are the last workgroup, then do reduction!
if (*counter_val == (1 << power_feature_workgroups) - 1) {
if (ltid == 0) {
sync_counters[feature_id] = 0;
}
#else
}
// only 1 work group, no need to increase counter
// the reduction will become a simple copy
{
unsigned int old_val; // dummy
#endif
// locate our feature's block in output memory
unsigned int output_offset = (feature_id << power_feature_workgroups);
acc_type const * __restrict__ feature_subhists =
reinterpret_cast<acc_type *>(output_buf) + output_offset * 3 * NUM_BINS;
// skip reading the data already in local memory
unsigned int skip_id = group_id - output_offset;
// locate output histogram location for this feature4
acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS;
within_kernel_reduction256x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast<acc_type *>(shared_array), power_feature_workgroups);
}
}
// end of histogram256 stuff
} // namespace LightGBM
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
#define LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
#include "LightGBM/meta.h"
namespace LightGBM {
// use double precision or not
#ifndef USE_DP_FLOAT
#define USE_DP_FLOAT 1
#endif
// ignore hessian, and use the local memory for hessian as an additional bank for gradient
#ifndef CONST_HESSIAN
#define CONST_HESSIAN 0
#endif
typedef unsigned char uchar;
template<typename T>
__device__ double as_double(const T t) {
static_assert(sizeof(T) == sizeof(double), "size mismatch");
double d;
memcpy(&d, &t, sizeof(T));
return d;
}
template<typename T>
__device__ unsigned long long as_ulong_ulong(const T t) {
static_assert(sizeof(T) == sizeof(unsigned long long), "size mismatch");
unsigned long long u;
memcpy(&u, &t, sizeof(T));
return u;
}
template<typename T>
__device__ float as_float(const T t) {
static_assert(sizeof(T) == sizeof(float), "size mismatch");
float f;
memcpy(&f, &t, sizeof(T));
return f;
}
template<typename T>
__device__ unsigned int as_uint(const T t) {
static_assert(sizeof(T) == sizeof(unsigned int), "size_mismatch");
unsigned int u;
memcpy(&u, &t, sizeof(T));
return u;
}
template<typename T>
__device__ uchar4 as_uchar4(const T t) {
static_assert(sizeof(T) == sizeof(uchar4), "size mismatch");
uchar4 u;
memcpy(&u, &t, sizeof(T));
return u;
}
#if USE_DP_FLOAT == 1
typedef double acc_type;
typedef unsigned long long acc_int_type;
#define as_acc_type as_double
#define as_acc_int_type as_ulong_ulong
#else
typedef float acc_type;
typedef unsigned int acc_int_type;
#define as_acc_type as_float
#define as_acc_int_type as_uint
#endif
// use all features and do not use feature mask
#ifndef ENABLE_ALL_FEATURES
#define ENABLE_ALL_FEATURES 1
#endif
// define all of the different kernels
#define DECLARE_CONST_BUF(name) \
__global__ void name(__global const uchar* restrict feature_data_base, \
const uchar* restrict feature_masks,\
const data_size_t feature_size,\
const data_size_t* restrict data_indices, \
const data_size_t num_data, \
const score_t* restrict ordered_gradients, \
const score_t* restrict ordered_hessians,\
char* __restrict__ output_buf,\
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
#define DECLARE_CONST_HES_CONST_BUF(name) \
__global__ void name(const uchar* __restrict__ feature_data_base, \
const uchar* __restrict__ feature_masks,\
const data_size_t feature_size,\
const data_size_t* __restrict__ data_indices, \
const data_size_t num_data, \
const score_t* __restrict__ ordered_gradients, \
const score_t const_hessian,\
char* __restrict__ output_buf,\
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
#define DECLARE_CONST_HES(name) \
__global__ void name(const uchar* feature_data_base, \
const uchar* __restrict__ feature_masks,\
const data_size_t feature_size,\
const data_size_t* data_indices, \
const data_size_t num_data, \
const score_t* ordered_gradients, \
const score_t const_hessian,\
char* __restrict__ output_buf, \
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
#define DECLARE(name) \
__global__ void name(const uchar* feature_data_base, \
const uchar* __restrict__ feature_masks,\
const data_size_t feature_size,\
const data_size_t* data_indices, \
const data_size_t num_data, \
const score_t* ordered_gradients, \
const score_t* ordered_hessians,\
char* __restrict__ output_buf, \
volatile int * sync_counters,\
acc_type* __restrict__ hist_buf_base, \
const size_t power_feature_workgroups);
DECLARE_CONST_HES(histogram16_allfeats);
DECLARE_CONST_HES(histogram16_fulldata);
DECLARE_CONST_HES(histogram16);
DECLARE(histogram16_allfeats);
DECLARE(histogram16_fulldata);
DECLARE(histogram16);
DECLARE_CONST_HES(histogram64_allfeats);
DECLARE_CONST_HES(histogram64_fulldata);
DECLARE_CONST_HES(histogram64);
DECLARE(histogram64_allfeats);
DECLARE(histogram64_fulldata);
DECLARE(histogram64);
DECLARE_CONST_HES(histogram256_allfeats);
DECLARE_CONST_HES(histogram256_fulldata);
DECLARE_CONST_HES(histogram256);
DECLARE(histogram256_allfeats);
DECLARE(histogram256_fulldata);
DECLARE(histogram256);
} // namespace LightGBM
#endif // LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
......@@ -12,6 +12,7 @@
#include <memory>
#include <vector>
#include "cuda_tree_learner.h"
#include "gpu_tree_learner.h"
#include "serial_tree_learner.h"
......
......@@ -326,7 +326,16 @@ void SerialTreeLearner::FindBestSplits(const Tree* tree) {
is_feature_used[feature_index] = 1;
}
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
#ifdef USE_CUDA
if (LGBM_config_::current_learner == use_cpu_learner) {
SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
} else {
ConstructHistograms(is_feature_used, use_subtract);
}
#else
ConstructHistograms(is_feature_used, use_subtract);
#endif
FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}
......
......@@ -8,6 +8,7 @@
#include <LightGBM/dataset.h>
#include <LightGBM/tree.h>
#include <LightGBM/tree_learner.h>
#include <LightGBM/cuda/vector_cudahost.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/json11.h>
#include <LightGBM/utils/random.h>
......@@ -201,6 +202,11 @@ class SerialTreeLearner: public TreeLearner {
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_hessians_;
#elif USE_CUDA
/*! \brief gradients of current iteration, ordered for cache optimized */
std::vector<score_t, CHAllocator<score_t>> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized */
std::vector<score_t, CHAllocator<score_t>> ordered_hessians_;
#else
/*! \brief gradients of current iteration, ordered for cache optimized */
std::vector<score_t, Common::AlignmentAllocator<score_t, kAlignedSize>> ordered_gradients_;
......
......@@ -4,6 +4,7 @@
*/
#include <LightGBM/tree_learner.h>
#include "cuda_tree_learner.h"
#include "gpu_tree_learner.h"
#include "parallel_tree_learner.h"
#include "serial_tree_learner.h"
......@@ -31,6 +32,16 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<GPUTreeLearner>(config);
}
} else if (device_type == std::string("cuda")) {
if (learner_type == std::string("serial")) {
return new CUDATreeLearner(config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<CUDATreeLearner>(config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<CUDATreeLearner>(config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<CUDATreeLearner>(config);
}
}
return nullptr;
}
......
......@@ -454,6 +454,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
}
// instantiate template classes, otherwise linker cannot find the code
template class VotingParallelTreeLearner<CUDATreeLearner>;
template class VotingParallelTreeLearner<GPUTreeLearner>;
template class VotingParallelTreeLearner<SerialTreeLearner>;
} // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment