Commit 0bb4a825 authored by Huan Zhang's avatar Huan Zhang Committed by Guolin Ke
Browse files

Initial GPU acceleration support for LightGBM (#368)

* add dummy gpu solver code

* initial GPU code

* fix crash bug

* first working version

* use asynchronous copy

* use a better kernel for root

* parallel read histogram

* sparse features now works, but no acceleration, compute on CPU

* compute sparse feature on CPU simultaneously

* fix big bug; add gpu selection; add kernel selection

* better debugging

* clean up

* add feature scatter

* Add sparse_threshold control

* fix a bug in feature scatter

* clean up debug

* temporarily add OpenCL kernels for k=64,256

* fix up CMakeList and definition USE_GPU

* add OpenCL kernels as string literals

* Add boost.compute as a submodule

* add boost dependency into CMakeList

* fix opencl pragma

* use pinned memory for histogram

* use pinned buffer for gradients and hessians

* better debugging message

* add double precision support on GPU

* fix boost version in CMakeList

* Add a README

* reconstruct GPU initialization code for ResetTrainingData

* move data to GPU in parallel

* fix a bug during feature copy

* update gpu kernels

* update gpu code

* initial port to LightGBM v2

* speedup GPU data loading process

* Add 4-bit bin support to GPU

* re-add sparse_threshold parameter

* remove kMaxNumWorkgroups and allows an unlimited number of features

* add feature mask support for skipping unused features

* enable kernel cache

* use GPU kernels withoug feature masks when all features are used

* REAdme.

* REAdme.

* update README

* fix typos (#349)

* change compile to gcc on Apple as default

* clean vscode related file

* refine api of constructing from sampling data.

* fix bug in the last commit.

* more efficient algorithm to sample k from n.

* fix bug in filter bin

* change to boost from average output.

* fix tests.

* only stop training when all classes are finshed in multi-class.

* limit the max tree output. change hessian in multi-class objective.

* robust tree model loading.

* fix test.

* convert the probabilities to raw score in boost_from_average of classification.

* fix the average label for binary classification.

* Add boost_from_average to docs (#354)

* don't use "ConvertToRawScore" for self-defined objective function.

* boost_from_average seems doesn't work well in binary classification. remove it.

* For a better jump link (#355)

* Update Python-API.md

* for a better jump in page

A space is needed between `#` and the headers content according to Github's markdown format [guideline](https://guides.github.com/features/mastering-markdown/)

After adding the spaces, we can jump to the exact position in page by click the link.

* fixed something mentioned by @wxchan

* Update Python-API.md

* add FitByExistingTree.

* adapt GPU tree learner for FitByExistingTree

* avoid NaN output.

* update boost.compute

* fix typos (#361)

* fix broken links (#359)

* update README

* disable GPU acceleration by default

* fix image url

* cleanup debug macro

* remove old README

* do not save sparse_threshold_ in FeatureGroup

* add details for new GPU settings

* ignore submodule when doing pep8 check

* allocate workspace for at least one thread during builing Feature4

* move sparse_threshold to class Dataset

* remove duplicated code in GPUTreeLearner::Split

* Remove duplicated code in FindBestThresholds and BeforeFindBestSplit

* do not rebuild ordered gradients and hessians for sparse features

* support feature groups in GPUTreeLearner

* Initial parallel learners with GPU support

* add option device, cleanup code

* clean up FindBestThresholds; add some omp parallel

* constant hessian optimization for GPU

* Fix GPUTreeLearner crash when there is zero feature

* use np.testing.assert_almost_equal() to compare lists of floats in tests

* travis for GPU
parent db3d1f89
#ifndef LIGHTGBM_TREELEARNER_GPU_TREE_LEARNER_H_
#define LIGHTGBM_TREELEARNER_GPU_TREE_LEARNER_H_
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/dataset.h>
#include <LightGBM/tree.h>
#include <LightGBM/feature_group.h>
#include "feature_histogram.hpp"
#include "serial_tree_learner.h"
#include "data_partition.hpp"
#include "split_info.hpp"
#include "leaf_splits.hpp"
#include <cstdio>
#include <vector>
#include <random>
#include <cmath>
#include <memory>
#ifdef USE_GPU
#define BOOST_COMPUTE_THREAD_SAFE
#define BOOST_COMPUTE_HAVE_THREAD_LOCAL
// Use Boost.Compute on-disk kernel cache
#define BOOST_COMPUTE_USE_OFFLINE_CACHE
#include <boost/compute/core.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/align/aligned_allocator.hpp>
namespace LightGBM {
/*!
* \brief GPU-based parallel learning algorithm.
*/
class GPUTreeLearner: public SerialTreeLearner {
public:
explicit GPUTreeLearner(const TreeConfig* tree_config);
~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian) override;
void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(used_indices, num_data);
// determine if we are using bagging before we construct the data partition
// thus we can start data movement to GPU earlier
if (used_indices != nullptr) {
if (num_data != num_data_) {
use_bagging_ = true;
return;
}
}
use_bagging_ = false;
}
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestThresholds() override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
private:
/*! \brief 4-byte feature tuple used by GPU kernels */
struct Feature4 {
union {
unsigned char s[4];
struct {
unsigned char s0;
unsigned char s1;
unsigned char s2;
unsigned char s3;
};
};
};
/*! \brief Single precision histogram entiry for GPU */
struct GPUHistogramBinEntry {
score_t sum_gradients;
score_t sum_hessians;
uint32_t cnt;
};
/*!
* \brief Find the best number of workgroups processing one feature for maximizing efficiency
* \param leaf_num_data The number of data examples on the current leaf being processed
* \return Log2 of the best number for workgroups per feature, in range 0...kMaxLogWorkgroupsPerFeature
*/
int GetNumWorkgroupsPerFeature(data_size_t leaf_num_data);
/*!
* \brief Initialize GPU device, context and command queues
* Also compiles the OpenCL kernel
* \param platform_id OpenCL platform ID
* \param device_id OpenCL device ID
*/
void InitGPU(int platform_id, int device_id);
/*!
* \brief Allocate memory for GPU computation
*/
void AllocateGPUMemory();
/*!
* \brief Compile OpenCL GPU source code to kernel binaries
*/
void BuildGPUKernels();
/*!
* \brief Setup GPU kernel arguments, preparing for launching
*/
void SetupKernelArguments();
/*!
* \brief Compute GPU feature histogram for the current leaf.
* Indices, gradients and hessians have been copied to the device.
* \param leaf_num_data Number of data on current leaf
* \param use_all_features Set to true to not use feature masks, with a faster kernel
*/
void GPUHistogram(data_size_t leaf_num_data, bool use_all_features);
/*!
* \brief Wait for GPU kernel execution and read histogram
* \param histograms Destination of histogram results from GPU.
* \param is_feature_used A predicate vector for enabling each feature
*/
template <typename HistType>
void WaitAndGetHistograms(HistogramBinEntry* histograms, const std::vector<int8_t>& is_feature_used);
/*!
* \brief Construct GPU histogram asynchronously.
* Interface is similar to Dataset::ConstructHistograms().
* \param is_feature_used A predicate vector for enabling each feature
* \param data_indices Array of data example IDs to be included in histogram, will be copied to GPU.
* Set to nullptr to skip copy to GPU.
* \param num_data Number of data examples to be included in histogram
* \param gradients Array of gradients for all examples.
* \param hessians Array of hessians for all examples.
* \param ordered_gradients Ordered gradients will be generated and copied to GPU when gradients is not nullptr,
* Set gradients to nullptr to skip copy to GPU.
* \param ordered_hessians Ordered hessians will be generated and copied to GPU when hessians is not nullptr,
* Set hessians to nullptr to skip copy to GPU.
* \return true if GPU kernel is launched, false if GPU is not used
*/
bool ConstructGPUHistogramsAsync(
const std::vector<int8_t>& is_feature_used,
const data_size_t* data_indices, data_size_t num_data,
const score_t* gradients, const score_t* hessians,
score_t* ordered_gradients, score_t* ordered_hessians);
/*! brief Log2 of max number of workgroups per feature*/
const int kMaxLogWorkgroupsPerFeature = 10; // 2^10
/*! brief Max total number of workgroups with preallocated workspace.
* If we use more than this number of workgroups, we have to reallocate subhistograms */
int preallocd_max_num_wg_ = 1024;
/*! \brief True if bagging is used */
bool use_bagging_;
/*! \brief GPU device object */
boost::compute::device dev_;
/*! \brief GPU context object */
boost::compute::context ctx_;
/*! \brief GPU command queue object */
boost::compute::command_queue queue_;
/*! \brief GPU kernel for 256 bins */
const char *kernel256_src_ =
#include "ocl/histogram256.cl"
;
/*! \brief GPU kernel for 64 bins */
const char *kernel64_src_ =
#include "ocl/histogram64.cl"
;
/*! \brief GPU kernel for 64 bins */
const char *kernel16_src_ =
#include "ocl/histogram16.cl"
;
/*! \brief Currently used kernel source */
std::string kernel_source_;
/*! \brief Currently used kernel name */
std::string kernel_name_;
/*! \brief a array of histogram kernels with different number
of workgroups per feature */
std::vector<boost::compute::kernel> histogram_kernels_;
/*! \brief a array of histogram kernels with different number
of workgroups per feature, with all features enabled to avoid branches */
std::vector<boost::compute::kernel> histogram_allfeats_kernels_;
/*! \brief a array of histogram kernels with different number
of workgroups per feature, and processing the whole dataset */
std::vector<boost::compute::kernel> histogram_fulldata_kernels_;
/*! \brief total number of feature-groups */
int num_feature_groups_;
/*! \brief total number of dense feature-groups, which will be processed on GPU */
int num_dense_feature_groups_;
/*! \brief On GPU we read one DWORD (4-byte) of features of one example once.
* With bin size > 16, there are 4 features per DWORD.
* With bin size <=16, there are 8 features per DWORD.
* */
int dword_features_;
/*! \brief total number of dense feature-group tuples on GPU.
* Each feature tuple is 4-byte (4 features if each feature takes a byte) */
int num_dense_feature4_;
/*! \brief Max number of bins of training data, used to determine
* which GPU kernel to use */
int max_num_bin_;
/*! \brief Used GPU kernel bin size (64, 256) */
int device_bin_size_;
/*! \brief Size of histogram bin entry, depending if single or double precision is used */
size_t hist_bin_entry_sz_;
/*! \brief Indices of all dense feature-groups */
std::vector<int> dense_feature_group_map_;
/*! \brief Indices of all sparse feature-groups */
std::vector<int> sparse_feature_group_map_;
/*! \brief Multipliers of all dense feature-groups, used for redistributing bins */
std::vector<int> device_bin_mults_;
/*! \brief GPU memory object holding the training data */
std::unique_ptr<boost::compute::vector<Feature4>> device_features_;
/*! \brief GPU memory object holding the ordered gradient */
boost::compute::buffer device_gradients_;
/*! \brief Pinned memory object for ordered gradient */
boost::compute::buffer pinned_gradients_;
/*! \brief Pointer to pinned memory of ordered gradient */
void * ptr_pinned_gradients_ = nullptr;
/*! \brief GPU memory object holding the ordered hessian */
boost::compute::buffer device_hessians_;
/*! \brief Pinned memory object for ordered hessian */
boost::compute::buffer pinned_hessians_;
/*! \brief Pointer to pinned memory of ordered hessian */
void * ptr_pinned_hessians_ = nullptr;
/*! \brief A vector of feature mask. 1 = feature used, 0 = feature not used */
std::vector<char, boost::alignment::aligned_allocator<char, 4096>> feature_masks_;
/*! \brief GPU memory object holding the feature masks */
boost::compute::buffer device_feature_masks_;
/*! \brief Pinned memory object for feature masks */
boost::compute::buffer pinned_feature_masks_;
/*! \brief Pointer to pinned memory of feature masks */
void * ptr_pinned_feature_masks_ = nullptr;
/*! \brief GPU memory object holding indices of the leaf being processed */
std::unique_ptr<boost::compute::vector<data_size_t>> device_data_indices_;
/*! \brief GPU memory object holding counters for workgroup coordination */
std::unique_ptr<boost::compute::vector<int>> sync_counters_;
/*! \brief GPU memory object holding temporary sub-histograms per workgroup */
std::unique_ptr<boost::compute::vector<char>> device_subhistograms_;
/*! \brief Host memory object for histogram output (GPU will write to Host memory directly) */
boost::compute::buffer device_histogram_outputs_;
/*! \brief Host memory pointer for histogram outputs */
void * host_histogram_outputs_;
/*! \brief OpenCL waitlist object for waiting for data transfer before kernel execution */
boost::compute::wait_list kernel_wait_obj_;
/*! \brief OpenCL waitlist object for reading output histograms after kernel execution */
boost::compute::wait_list histograms_wait_obj_;
/*! \brief Asynchronous waiting object for copying indices */
boost::compute::future<void> indices_future_;
/*! \brief Asynchronous waiting object for copying gradients */
boost::compute::event gradients_future_;
/*! \brief Asynchronous waiting object for copying hessians */
boost::compute::event hessians_future_;
};
} // namespace LightGBM
#else
// When GPU support is not compiled in, quit with an error message
namespace LightGBM {
class GPUTreeLearner: public SerialTreeLearner {
public:
explicit GPUTreeLearner(const TreeConfig* tree_config) : SerialTreeLearner(tree_config) {
Log::Fatal("GPU Tree Learner was not enabled in this build. Recompile with CMake option -DUSE_GPU=1");
}
};
}
#endif // USE_GPU
#endif // LightGBM_TREELEARNER_GPU_TREE_LEARNER_H_
// this file can either be read and passed to an OpenCL compiler directly,
// or included in a C++11 source file as a string literal
#ifndef __OPENCL_VERSION__
// If we are including this file in C++,
// the entire source file following (except the last #endif) will become
// a raw string literal. The extra ")" is just for mathcing parentheses
// to make the editor happy. The extra ")" and extra endif will be skipped.
// DO NOT add anything between here and the next #ifdef, otherwise you need
// to modify the skip count at the end of this file.
R""()
#endif
#ifndef _HISTOGRAM_16_KERNEL_
#define _HISTOGRAM_16_KERNEL_
#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
// Configurable options:
// NUM_BANKS should be a power of 2
#ifndef NUM_BANKS
#define NUM_BANKS 8
#endif
// how many bits in thread ID represent the bank = log2(NUM_BANKS)
#ifndef BANK_BITS
#define BANK_BITS 3
#endif
// use double precision or not
#ifndef USE_DP_FLOAT
#define USE_DP_FLOAT 0
#endif
// ignore hessian, and use the local memory for hessian as an additional bank for gradient
#ifndef CONST_HESSIAN
#define CONST_HESSIAN 0
#endif
#define LOCAL_SIZE_0 256
#define NUM_BINS 16
// if USE_DP_FLOAT is set to 1, we will use double precision for the accumulator
#if USE_DP_FLOAT == 1
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
typedef double acc_type;
typedef ulong acc_int_type;
#define as_acc_type as_double
#define as_acc_int_type as_ulong
#else
typedef float acc_type;
typedef uint acc_int_type;
#define as_acc_type as_float
#define as_acc_int_type as_uint
#endif
// number of features to process in a 4-byte feature tuple
#define DWORD_FEATURES 8
// number of bits per feature
#define FEATURE_BITS (sizeof(uchar4) * 8 / DWORD_FEATURES)
// bit mask for number of features to process in a 4-byte feature tuple
#define DWORD_FEATURES_MASK (DWORD_FEATURES - 1)
// log2 of number of features to process in a 4-byte feature tuple
#define LOG2_DWORD_FEATURES 3
// mask for getting the bank ID
#define BANK_MASK (NUM_BANKS - 1)
// 8 features, each has a gradient and a hessian
#define HG_BIN_MULT (NUM_BANKS * DWORD_FEATURES * 2)
// 8 features, each has a counter
#define CNT_BIN_MULT (NUM_BANKS * DWORD_FEATURES)
// local memory size in bytes
#define LOCAL_MEM_SIZE (DWORD_FEATURES * (sizeof(uint) + 2 * sizeof(acc_type)) * NUM_BINS * NUM_BANKS)
// unroll the atomic operation for a few times. Takes more code space,
// but compiler can generate better code for faster atomics.
#define UNROLL_ATOMIC 1
// Options passed by compiler at run time:
// IGNORE_INDICES will be set when the kernel does not
// #define IGNORE_INDICES
// #define POWER_FEATURE_WORKGROUPS 10
// use all features and do not use feature mask
#ifndef ENABLE_ALL_FEATURES
#define ENABLE_ALL_FEATURES 1
#endif
// detect Nvidia platforms
#ifdef cl_nv_pragma_unroll
#define NVIDIA 1
#endif
// use binary patching for AMD GCN 1.2 or newer
#ifndef AMD_USE_DS_ADD_F32
#define AMD_USE_DS_ADD_F32 0
#endif
typedef uint data_size_t;
typedef float score_t;
#define ATOMIC_FADD_SUB1 { \
expected.f_val = current.f_val; \
next.f_val = expected.f_val + val; \
current.u_val = atom_cmpxchg((volatile __local acc_int_type *)addr, expected.u_val, next.u_val); \
if (current.u_val == expected.u_val) \
goto end; \
}
#define ATOMIC_FADD_SUB2 ATOMIC_FADD_SUB1 \
ATOMIC_FADD_SUB1
#define ATOMIC_FADD_SUB4 ATOMIC_FADD_SUB2 \
ATOMIC_FADD_SUB2
#define ATOMIC_FADD_SUB8 ATOMIC_FADD_SUB4 \
ATOMIC_FADD_SUB4
#define ATOMIC_FADD_SUB16 ATOMIC_FADD_SUB8 \
ATOMIC_FADD_SUB8
#define ATOMIC_FADD_SUB32 ATOMIC_FADD_SUB16\
ATOMIC_FADD_SUB16
#define ATOMIC_FADD_SUB64 ATOMIC_FADD_SUB32\
ATOMIC_FADD_SUB32
// atomic add for float number in local memory
inline void atomic_local_add_f(__local acc_type *addr, const float val)
{
union{
acc_int_type u_val;
acc_type f_val;
} next, expected, current;
#if (NVIDIA == 1 && USE_DP_FLOAT == 0)
float res = 0;
asm volatile ("atom.shared.add.f32 %0, [%1], %2;" : "=f"(res) : "l"(addr), "f"(val));
#elif (AMD_USE_DS_ADD_F32 == 1 && USE_DP_FLAT == 0)
// this instruction (DS_AND_U32) will be patched into a DS_ADD_F32
// we need to hack here because DS_ADD_F32 is not exposed via OpenCL
atom_and((__local acc_int_type *)addr, as_acc_int_type(val));
#else
current.f_val = *addr;
#if UNROLL_ATOMIC == 1
// provide a fast path
// then do the complete loop
// this should work on all devices
ATOMIC_FADD_SUB8
ATOMIC_FADD_SUB4
#endif
do {
expected.f_val = current.f_val;
next.f_val = expected.f_val + val;
current.u_val = atom_cmpxchg((volatile __local acc_int_type *)addr, expected.u_val, next.u_val);
} while (current.u_val != expected.u_val);
end:
;
#endif
}
// this function will be called by histogram16
// we have one sub-histogram of one feature in registers, and need to read others
void within_kernel_reduction16x8(uchar8 feature_mask,
__global const acc_type* restrict feature4_sub_hist,
const uint skip_id,
acc_type stat_val, uint cnt_val,
const ushort num_sub_hist,
__global acc_type* restrict output_buf,
__local acc_type * restrict local_hist) {
const ushort ltid = get_local_id(0); // range 0 - 255
const ushort lsize = LOCAL_SIZE_0;
ushort feature_id = ltid & DWORD_FEATURES_MASK; // range 0 - 7
uchar is_hessian_first = (ltid >> LOG2_DWORD_FEATURES) & 1; // hessian or gradient
ushort bin_id = ltid >> (LOG2_DWORD_FEATURES + 1); // range 0 - 16
ushort i;
#if POWER_FEATURE_WORKGROUPS != 0
// if there is only 1 work group, no need to do the reduction
// add all sub-histograms for 4 features
__global const acc_type* restrict p = feature4_sub_hist + ltid;
for (i = 0; i < skip_id; ++i) {
// 256 threads working on 8 features' 16 bins, gradient and hessian
stat_val += *p;
p += NUM_BINS * DWORD_FEATURES * 2;
if (ltid < LOCAL_SIZE_0 / 2) {
cnt_val += as_acc_int_type(*p);
}
p += NUM_BINS * DWORD_FEATURES;
}
// skip the counters we already have
p += 3 * DWORD_FEATURES * NUM_BINS;
for (i = i + 1; i < num_sub_hist; ++i) {
stat_val += *p;
p += NUM_BINS * DWORD_FEATURES * 2;
if (ltid < LOCAL_SIZE_0 / 2) {
cnt_val += as_acc_int_type(*p);
}
p += NUM_BINS * DWORD_FEATURES;
}
#endif
// printf("thread %d:feature=%d, bin_id=%d, hessian=%d, stat_val=%f, cnt=%d", ltid, feature_id, bin_id, is_hessian_first, stat_val, cnt_val);
// now overwrite the local_hist for final reduction and output
// reverse the f7...f0 order to match the real order
feature_id = DWORD_FEATURES_MASK - feature_id;
local_hist[feature_id * 3 * NUM_BINS + bin_id * 3 + is_hessian_first] = stat_val;
bin_id = ltid >> (LOG2_DWORD_FEATURES); // range 0 - 16, for counter
if (ltid < LOCAL_SIZE_0 / 2) {
local_hist[feature_id * 3 * NUM_BINS + bin_id * 3 + 2] = as_acc_type((acc_int_type)cnt_val);
}
barrier(CLK_LOCAL_MEM_FENCE);
for (i = ltid; i < DWORD_FEATURES * 3 * NUM_BINS; i += lsize) {
output_buf[i] = local_hist[i];
}
}
__attribute__((reqd_work_group_size(LOCAL_SIZE_0, 1, 1)))
#if USE_CONSTANT_BUF == 1
__kernel void histogram16(__global const uchar4* restrict feature_data_base,
__constant const uchar8* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
const data_size_t num_data,
__constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
#if CONST_HESSIAN == 0
__constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
#else
const score_t const_hessian,
#endif
__global char* restrict output_buf,
__global volatile int * sync_counters,
__global acc_type* restrict hist_buf_base) {
#else
__kernel void histogram16(__global const uchar4* feature_data_base,
__constant const uchar8* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__global const data_size_t* data_indices,
const data_size_t num_data,
__global const score_t* ordered_gradients,
#if CONST_HESSIAN == 0
__global const score_t* ordered_hessians,
#else
const score_t const_hessian,
#endif
__global char* restrict output_buf,
__global volatile int * sync_counters,
__global acc_type* restrict hist_buf_base) {
#endif
// allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
// otherwise a "Misaligned Address" exception may occur
__local float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
const uint gtid = get_global_id(0);
const uint gsize = get_global_size(0);
const ushort ltid = get_local_id(0);
const ushort lsize = LOCAL_SIZE_0; // get_local_size(0);
const ushort group_id = get_group_id(0);
// local memory per workgroup is 12 KB
// clear local memory
__local uint * ptr = (__local uint *) shared_array;
for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(uint); i += lsize) {
ptr[i] = 0;
}
barrier(CLK_LOCAL_MEM_FENCE);
// gradient/hessian histograms
// assume this starts at 32 * 4 = 128-byte boundary
// each bank: 2 * 8 * 16 * size_of(float) = 1 KB
// there are 8 banks (sub-histograms) used by 256 threads total 8 KB
/* memory layout of gh_hist:
-----------------------------------------------------------------------------------------------
bk0_g_f0_bin0 bk0_g_f1_bin0 bk0_g_f2_bin0 bk0_g_f3_bin0 bk0_g_f4_bin0 bk0_g_f5_bin0 bk0_g_f6_bin0 bk0_g_f7_bin0
bk0_h_f0_bin0 bk0_h_f1_bin0 bk0_h_f2_bin0 bk0_h_f3_bin0 bk0_h_f4_bin0 bk0_h_f5_bin0 bk0_h_f6_bin0 bk0_h_f7_bin0
bk1_g_f0_bin0 bk1_g_f1_bin0 bk1_g_f2_bin0 bk1_g_f3_bin0 bk1_g_f4_bin0 bk1_g_f5_bin0 bk1_g_f6_bin0 bk1_g_f7_bin0
bk1_h_f0_bin0 bk1_h_f1_bin0 bk1_h_f2_bin0 bk1_h_f3_bin0 bk1_h_f4_bin0 bk1_h_f5_bin0 bk1_h_f6_bin0 bk1_h_f7_bin0
bk2_g_f0_bin0 bk2_g_f1_bin0 bk2_g_f2_bin0 bk2_g_f3_bin0 bk2_g_f4_bin0 bk2_g_f5_bin0 bk2_g_f6_bin0 bk2_g_f7_bin0
bk2_h_f0_bin0 bk2_h_f1_bin0 bk2_h_f2_bin0 bk2_h_f3_bin0 bk2_h_f4_bin0 bk2_h_f5_bin0 bk2_h_f6_bin0 bk2_h_f7_bin0
bk3_g_f0_bin0 bk3_g_f1_bin0 bk3_g_f2_bin0 bk3_g_f3_bin0 bk3_g_f4_bin0 bk3_g_f5_bin0 bk3_g_f6_bin0 bk3_g_f7_bin0
bk3_h_f0_bin0 bk3_h_f1_bin0 bk3_h_f2_bin0 bk3_h_f3_bin0 bk3_h_f4_bin0 bk3_h_f5_bin0 bk3_h_f6_bin0 bk3_h_f7_bin0
bk4_g_f0_bin0 bk4_g_f1_bin0 bk4_g_f2_bin0 bk4_g_f3_bin0 bk4_g_f4_bin0 bk4_g_f5_bin0 bk4_g_f6_bin0 bk4_g_f7_bin0
bk4_h_f0_bin0 bk4_h_f1_bin0 bk4_h_f2_bin0 bk4_h_f3_bin0 bk4_h_f4_bin0 bk4_h_f5_bin0 bk4_h_f6_bin0 bk4_h_f7_bin0
bk5_g_f0_bin0 bk5_g_f1_bin0 bk5_g_f2_bin0 bk5_g_f3_bin0 bk5_g_f4_bin0 bk5_g_f5_bin0 bk5_g_f6_bin0 bk5_g_f7_bin0
bk5_h_f0_bin0 bk5_h_f1_bin0 bk5_h_f2_bin0 bk5_h_f3_bin0 bk5_h_f4_bin0 bk5_h_f5_bin0 bk5_h_f6_bin0 bk5_h_f7_bin0
bk6_g_f0_bin0 bk6_g_f1_bin0 bk6_g_f2_bin0 bk6_g_f3_bin0 bk6_g_f4_bin0 bk6_g_f5_bin0 bk6_g_f6_bin0 bk6_g_f7_bin0
bk6_h_f0_bin0 bk6_h_f1_bin0 bk6_h_f2_bin0 bk6_h_f3_bin0 bk6_h_f4_bin0 bk6_h_f5_bin0 bk6_h_f6_bin0 bk6_h_f7_bin0
bk7_g_f0_bin0 bk7_g_f1_bin0 bk7_g_f2_bin0 bk7_g_f3_bin0 bk7_g_f4_bin0 bk7_g_f5_bin0 bk7_g_f6_bin0 bk7_g_f7_bin0
bk7_h_f0_bin0 bk7_h_f1_bin0 bk7_h_f2_bin0 bk7_h_f3_bin0 bk7_h_f4_bin0 bk7_h_f5_bin0 bk7_h_f6_bin0 bk7_h_f7_bin0
...
bk0_g_f0_bin16 bk0_g_f1_bin16 bk0_g_f2_bin16 bk0_g_f3_bin16 bk0_g_f4_bin16 bk0_g_f5_bin16 bk0_g_f6_bin16 bk0_g_f7_bin16
bk0_h_f0_bin16 bk0_h_f1_bin16 bk0_h_f2_bin16 bk0_h_f3_bin16 bk0_h_f4_bin16 bk0_h_f5_bin16 bk0_h_f6_bin16 bk0_h_f7_bin16
bk1_g_f0_bin16 bk1_g_f1_bin16 bk1_g_f2_bin16 bk1_g_f3_bin16 bk1_g_f4_bin16 bk1_g_f5_bin16 bk1_g_f6_bin16 bk1_g_f7_bin16
bk1_h_f0_bin16 bk1_h_f1_bin16 bk1_h_f2_bin16 bk1_h_f3_bin16 bk1_h_f4_bin16 bk1_h_f5_bin16 bk1_h_f6_bin16 bk1_h_f7_bin16
bk2_g_f0_bin16 bk2_g_f1_bin16 bk2_g_f2_bin16 bk2_g_f3_bin16 bk2_g_f4_bin16 bk2_g_f5_bin16 bk2_g_f6_bin16 bk2_g_f7_bin16
bk2_h_f0_bin16 bk2_h_f1_bin16 bk2_h_f2_bin16 bk2_h_f3_bin16 bk2_h_f4_bin16 bk2_h_f5_bin16 bk2_h_f6_bin16 bk2_h_f7_bin16
bk3_g_f0_bin16 bk3_g_f1_bin16 bk3_g_f2_bin16 bk3_g_f3_bin16 bk3_g_f4_bin16 bk3_g_f5_bin16 bk3_g_f6_bin16 bk3_g_f7_bin16
bk3_h_f0_bin16 bk3_h_f1_bin16 bk3_h_f2_bin16 bk3_h_f3_bin16 bk3_h_f4_bin16 bk3_h_f5_bin16 bk3_h_f6_bin16 bk3_h_f7_bin16
bk4_g_f0_bin16 bk4_g_f1_bin16 bk4_g_f2_bin16 bk4_g_f3_bin16 bk4_g_f4_bin16 bk4_g_f5_bin16 bk4_g_f6_bin16 bk4_g_f7_bin16
bk4_h_f0_bin16 bk4_h_f1_bin16 bk4_h_f2_bin16 bk4_h_f3_bin16 bk4_h_f4_bin16 bk4_h_f5_bin16 bk4_h_f6_bin16 bk4_h_f7_bin16
bk5_g_f0_bin16 bk5_g_f1_bin16 bk5_g_f2_bin16 bk5_g_f3_bin16 bk5_g_f4_bin16 bk5_g_f5_bin16 bk5_g_f6_bin16 bk5_g_f7_bin16
bk5_h_f0_bin16 bk5_h_f1_bin16 bk5_h_f2_bin16 bk5_h_f3_bin16 bk5_h_f4_bin16 bk5_h_f5_bin16 bk5_h_f6_bin16 bk5_h_f7_bin16
bk6_g_f0_bin16 bk6_g_f1_bin16 bk6_g_f2_bin16 bk6_g_f3_bin16 bk6_g_f4_bin16 bk6_g_f5_bin16 bk6_g_f6_bin16 bk6_g_f7_bin16
bk6_h_f0_bin16 bk6_h_f1_bin16 bk6_h_f2_bin16 bk6_h_f3_bin16 bk6_h_f4_bin16 bk6_h_f5_bin16 bk6_h_f6_bin16 bk6_h_f7_bin16
bk7_g_f0_bin16 bk7_g_f1_bin16 bk7_g_f2_bin16 bk7_g_f3_bin16 bk7_g_f4_bin16 bk7_g_f5_bin16 bk7_g_f6_bin16 bk7_g_f7_bin16
bk7_h_f0_bin16 bk7_h_f1_bin16 bk7_h_f2_bin16 bk7_h_f3_bin16 bk7_h_f4_bin16 bk7_h_f5_bin16 bk7_h_f6_bin16 bk7_h_f7_bin16
-----------------------------------------------------------------------------------------------
*/
// with this organization, the LDS/shared memory bank is independent of the bin value
// all threads within a quarter-wavefront (half-warp) will not have any bank conflict
__local acc_type * gh_hist = (__local acc_type *)shared_array;
// counter histogram
// each bank: 8 * 16 * size_of(uint) = 0.5 KB
// there are 8 banks used by 256 threads total 4 KB
/* memory layout in cnt_hist:
-----------------------------------------------
bk0_c_f0_bin0 bk0_c_f1_bin0 bk0_c_f2_bin0 bk0_c_f3_bin0 bk0_c_f4_bin0 bk0_c_f5_bin0 bk0_c_f6_bin0 bk0_c_f7_bin0
bk1_c_f0_bin0 bk1_c_f1_bin0 bk1_c_f2_bin0 bk1_c_f3_bin0 bk1_c_f4_bin0 bk1_c_f5_bin0 bk1_c_f6_bin0 bk1_c_f7_bin0
bk2_c_f0_bin0 bk2_c_f1_bin0 bk2_c_f2_bin0 bk2_c_f3_bin0 bk2_c_f4_bin0 bk2_c_f5_bin0 bk2_c_f6_bin0 bk2_c_f7_bin0
bk3_c_f0_bin0 bk3_c_f1_bin0 bk3_c_f2_bin0 bk3_c_f3_bin0 bk3_c_f4_bin0 bk3_c_f5_bin0 bk3_c_f6_bin0 bk3_c_f7_bin0
bk4_c_f0_bin0 bk4_c_f1_bin0 bk4_c_f2_bin0 bk4_c_f3_bin0 bk4_c_f4_bin0 bk4_c_f5_bin0 bk4_c_f6_bin0 bk4_c_f7_bin0
bk5_c_f0_bin0 bk5_c_f1_bin0 bk5_c_f2_bin0 bk5_c_f3_bin0 bk5_c_f4_bin0 bk5_c_f5_bin0 bk5_c_f6_bin0 bk5_c_f7_bin0
bk6_c_f0_bin0 bk6_c_f1_bin0 bk6_c_f2_bin0 bk6_c_f3_bin0 bk6_c_f4_bin0 bk6_c_f5_bin0 bk6_c_f6_bin0 bk6_c_f7_bin0
bk7_c_f0_bin0 bk7_c_f1_bin0 bk7_c_f2_bin0 bk7_c_f3_bin0 bk7_c_f4_bin0 bk7_c_f5_bin0 bk7_c_f6_bin0 bk7_c_f7_bin0
...
bk0_c_f0_bin16 bk0_c_f1_bin16 bk0_c_f2_bin16 bk0_c_f3_bin16 bk0_c_f4_bin16 bk0_c_f5_bin16 bk0_c_f6_bin16 bk0_c_f7_bin0
bk1_c_f0_bin16 bk1_c_f1_bin16 bk1_c_f2_bin16 bk1_c_f3_bin16 bk1_c_f4_bin16 bk1_c_f5_bin16 bk1_c_f6_bin16 bk1_c_f7_bin0
bk2_c_f0_bin16 bk2_c_f1_bin16 bk2_c_f2_bin16 bk2_c_f3_bin16 bk2_c_f4_bin16 bk2_c_f5_bin16 bk2_c_f6_bin16 bk2_c_f7_bin0
bk3_c_f0_bin16 bk3_c_f1_bin16 bk3_c_f2_bin16 bk3_c_f3_bin16 bk3_c_f4_bin16 bk3_c_f5_bin16 bk3_c_f6_bin16 bk3_c_f7_bin0
bk4_c_f0_bin16 bk4_c_f1_bin16 bk4_c_f2_bin16 bk4_c_f3_bin16 bk4_c_f4_bin16 bk4_c_f5_bin16 bk4_c_f6_bin16 bk4_c_f7_bin0
bk5_c_f0_bin16 bk5_c_f1_bin16 bk5_c_f2_bin16 bk5_c_f3_bin16 bk5_c_f4_bin16 bk5_c_f5_bin16 bk5_c_f6_bin16 bk5_c_f7_bin0
bk6_c_f0_bin16 bk6_c_f1_bin16 bk6_c_f2_bin16 bk6_c_f3_bin16 bk6_c_f4_bin16 bk6_c_f5_bin16 bk6_c_f6_bin16 bk6_c_f7_bin0
bk7_c_f0_bin16 bk7_c_f1_bin16 bk7_c_f2_bin16 bk7_c_f3_bin16 bk7_c_f4_bin16 bk7_c_f5_bin16 bk7_c_f6_bin16 bk7_c_f7_bin0
-----------------------------------------------
*/
__local uint * cnt_hist = (__local uint *)(gh_hist + 2 * DWORD_FEATURES * NUM_BINS * NUM_BANKS);
// thread 0, 1, 2, 3, 4, 5, 6, 7 compute histograms for gradients first
// thread 8, 9, 10, 11, 12, 13, 14, 15 compute histograms for hessians first
// etc.
uchar is_hessian_first = (ltid >> LOG2_DWORD_FEATURES) & 1;
// thread 0-15 write result to bank0, 16-31 to bank1, 32-47 to bank2, 48-63 to bank3, etc
ushort bank = (ltid >> (LOG2_DWORD_FEATURES + 1)) & BANK_MASK;
ushort group_feature = group_id >> POWER_FEATURE_WORKGROUPS;
// each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
// feature_size is the number of examples per feature
__global const uchar4* feature_data = feature_data_base + group_feature * feature_size;
// size of threads that process this feature4
const uint subglobal_size = lsize * (1 << POWER_FEATURE_WORKGROUPS);
// equavalent thread ID in this subgroup for this feature4
const uint subglobal_tid = gtid - group_feature * subglobal_size;
// extract feature mask, when a byte is set to 0, that feature is disabled
#if ENABLE_ALL_FEATURES == 1
// hopefully the compiler will propogate the constants and eliminate all branches
uchar8 feature_mask = (uchar8)(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff);
#else
uchar8 feature_mask = feature_masks[group_feature];
#endif
// exit if all features are masked
if (!as_ulong(feature_mask)) {
return;
}
// STAGE 1: read feature data, and gradient and hessian
// first half of the threads read feature data from global memory
// 4 features stored in a tuple MSB...(0, 1, 2, 3)...LSB
// We will prefetch data into the "next" variable at the beginning of each iteration
uchar4 feature4;
uchar4 feature4_next;
// offset used to rotate feature4 vector, & 0x7
ushort offset = (ltid & DWORD_FEATURES_MASK);
#if ENABLE_ALL_FEATURES == 0
// rotate feature_mask to match the feature order of each thread
feature_mask = as_uchar8(rotate(as_ulong(feature_mask), (ulong)offset*8));
#endif
// store gradient and hessian
float stat1, stat2;
float stat1_next, stat2_next;
ushort bin, addr, addr2;
data_size_t ind;
data_size_t ind_next;
stat1 = ordered_gradients[subglobal_tid];
#if CONST_HESSIAN == 0
stat2 = ordered_hessians[subglobal_tid];
#endif
#ifdef IGNORE_INDICES
ind = subglobal_tid;
#else
ind = data_indices[subglobal_tid];
#endif
feature4 = feature_data[ind];
// there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
for (uint i = subglobal_tid; i < num_data; i += subglobal_size) {
// prefetch the next iteration variables
// we don't need bondary check because we have made the buffer larger
stat1_next = ordered_gradients[i + subglobal_size];
#if CONST_HESSIAN == 0
stat2_next = ordered_hessians[i + subglobal_size];
#endif
#ifdef IGNORE_INDICES
// we need to check to bounds here
ind_next = i + subglobal_size < num_data ? i + subglobal_size : i;
// start load next feature as early as possible
feature4_next = feature_data[ind_next];
#else
ind_next = data_indices[i + subglobal_size];
#endif
#if CONST_HESSIAN == 0
// swap gradient and hessian for threads 8, 9, 10, 11, 12, 13, 14, 15
float tmp = stat1;
stat1 = is_hessian_first ? stat2 : stat1;
stat2 = is_hessian_first ? tmp : stat2;
// stat1 = select(stat1, stat2, is_hessian_first);
// stat2 = select(stat2, tmp, is_hessian_first);
#endif
// STAGE 2: accumulate gradient and hessian
offset = (ltid & DWORD_FEATURES_MASK);
// printf("thread %x, %08x -> %08x", ltid, as_uint(feature4), rotate(as_uint(feature4), (uint)(offset * FEATURE_BITS)));
feature4 = as_uchar4(rotate(as_uint(feature4), (uint)(offset * FEATURE_BITS)));
if (feature_mask.s7) {
bin = feature4.s3 >> 4;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 0, 1, 2, 3, 4, 5, 6 ,7's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 0, 1, 2, 3, 4, 5, 6, 7's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 0, 1, 2, 3, 4, 5, 6, 7's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 0, 1, 2, 3, 4, 5, 6, 7's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s6) {
bin = feature4.s3 & 0xf;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 1, 2, 3, 4, 5, 6 ,7, 0's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 1, 2, 3, 4, 5, 6, 7, 0's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 1, 2, 3, 4, 5, 6, 7, 0's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 1, 2, 3, 4, 5, 6, 7, 0's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s5) {
bin = feature4.s2 >> 4;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 2, 3, 4, 5, 6, 7, 0, 1's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 2, 3, 4, 5, 6, 7, 0, 1's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 2, 3, 4, 5, 6, 7, 0, 1's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 2, 3, 4, 5, 6, 7, 0, 1's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s4) {
bin = feature4.s2 & 0xf;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 3, 4, 5, 6, 7, 0, 1, 2's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 3, 4, 5, 6, 7, 0, 1, 2's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 3, 4, 5, 6, 7, 0, 1, 2's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 3, 4, 5, 6, 7, 0, 1, 2's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
// prefetch the next iteration variables
// we don't need bondary check because if it is out of boundary, ind_next = 0
#ifndef IGNORE_INDICES
feature4_next = feature_data[ind_next];
#endif
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s3) {
bin = feature4.s1 >> 4;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 4, 5, 6, 7, 0, 1, 2, 3's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 4, 5, 6, 7, 0, 1, 2, 3's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 4, 5, 6, 7, 0, 1, 2, 3's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 4, 5, 6, 7, 0, 1, 2, 3's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s2) {
bin = feature4.s1 & 0xf;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 5, 6, 7, 0, 1, 2, 3, 4's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 5, 6, 7, 0, 1, 2, 3, 4's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 5, 6, 7, 0, 1, 2, 3, 4's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 5, 6, 7, 0, 1, 2, 3, 4's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s1) {
bin = feature4.s0 >> 4;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 6, 7, 0, 1, 2, 3, 4, 5's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 6, 7, 0, 1, 2, 3, 4, 5's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 6, 7, 0, 1, 2, 3, 4, 5's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 6, 7, 0, 1, 2, 3, 4, 5's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s0) {
bin = feature4.s0 & 0xf;
addr = bin * HG_BIN_MULT + bank * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + offset;
addr2 = addr + DWORD_FEATURES - 2 * DWORD_FEATURES * is_hessian_first;
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 7, 0, 1, 2, 3, 4, 5, 6's gradients for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 7, 0, 1, 2, 3, 4, 5, 6's hessians for example 8, 9, 10, 11, 12, 13, 14, 15
atomic_local_add_f(gh_hist + addr, stat1);
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 7, 0, 1, 2, 3, 4, 5, 6's hessians for example 0, 1, 2, 3, 4, 5, 6, 7
// thread 8, 9, 10, 11, 12, 13, 14, 15 now process feature 7, 0, 1, 2, 3, 4, 5, 6's gradients for example 8, 9, 10, 11, 12, 13, 14, 15
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, stat2);
#endif
}
// STAGE 3: accumulate counter
// there are 8 counters for 8 features
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 0, 1, 2, 3, 4, 5, 6, 7's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (ltid & DWORD_FEATURES_MASK);
if (feature_mask.s7) {
bin = feature4.s3 >> 4;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (0)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 1, 2, 3, 4, 5, 6, 7, 0's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s6) {
bin = feature4.s3 & 0xf;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (1)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 2, 3, 4, 5, 6, 7, 0, 1's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s5) {
bin = feature4.s2 >> 4;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (2)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 3, 4, 5, 6, 7, 0, 1, 2's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s4) {
bin = feature4.s2 & 0xf;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (3)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 4, 5, 6, 7, 0, 1, 2, 3's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s3) {
bin = feature4.s1 >> 4;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (4)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 5, 6, 7, 0, 1, 2, 3, 4's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s2) {
bin = feature4.s1 & 0xf;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (5)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 6, 7, 0, 1, 2, 3, 4, 5's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s1) {
bin = feature4.s0 >> 4;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (6)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3, 4, 5, 6, 7 now process feature 7, 0, 1, 2, 3, 4, 5, 6's counts for example 0, 1, 2, 3, 4, 5, 6, 7
offset = (offset + 1) & DWORD_FEATURES_MASK;
if (feature_mask.s0) {
bin = feature4.s0 & 0xf;
addr = bin * CNT_BIN_MULT + bank * DWORD_FEATURES + offset;
// printf("thread %x add counter %d feature %d (7)\n", ltid, bin, offset);
atom_inc(cnt_hist + addr);
}
stat1 = stat1_next;
stat2 = stat2_next;
feature4 = feature4_next;
}
barrier(CLK_LOCAL_MEM_FENCE);
#if ENABLE_ALL_FEATURES == 0
// restore feature_mask
feature_mask = feature_masks[group_feature];
#endif
// now reduce the 4 banks of subhistograms into 1
acc_type stat_val = 0.0f;
uint cnt_val = 0;
// 256 threads, working on 8 features and 16 bins, 2 stats
// so each thread has an independent feature/bin/stat to work on.
const ushort feature_id = ltid & DWORD_FEATURES_MASK; // bits 0 - 2 of ltid, range 0 - 7
ushort bin_id = ltid >> (LOG2_DWORD_FEATURES + 1); // bits 3 is is_hessian_first; bits 4 - 7 range 0 - 16 is bin ID
offset = (ltid >> (LOG2_DWORD_FEATURES + 1)) & BANK_MASK; // helps avoid LDS bank conflicts
for (int i = 0; i < NUM_BANKS; ++i) {
ushort bank_id = (i + offset) & BANK_MASK;
stat_val += gh_hist[bin_id * HG_BIN_MULT + bank_id * 2 * DWORD_FEATURES + is_hessian_first * DWORD_FEATURES + feature_id];
}
if (ltid < LOCAL_SIZE_0 / 2) {
// first 128 threads accumulate the 8 * 16 = 128 counter values
bin_id = ltid >> LOG2_DWORD_FEATURES; // bits 3 - 6 range 0 - 16 is bin ID
offset = (ltid >> LOG2_DWORD_FEATURES) & BANK_MASK; // helps avoid LDS bank conflicts
for (int i = 0; i < NUM_BANKS; ++i) {
ushort bank_id = (i + offset) & BANK_MASK;
cnt_val += cnt_hist[bin_id * CNT_BIN_MULT + bank_id * DWORD_FEATURES + feature_id];
}
}
// now thread 0 - 7 holds feature 0 - 7's gradient for bin 0 and counter bin 0
// now thread 8 - 15 holds feature 0 - 7's hessian for bin 0 and counter bin 1
// now thread 16- 23 holds feature 0 - 7's gradient for bin 1 and counter bin 2
// now thread 24- 31 holds feature 0 - 7's hessian for bin 1 and counter bin 3
// etc,
#if CONST_HESSIAN == 1
// Combine the two banks into one, and fill the hessians with counter value * hessian constant
barrier(CLK_LOCAL_MEM_FENCE);
gh_hist[ltid] = stat_val;
if (ltid < LOCAL_SIZE_0 / 2) {
cnt_hist[ltid] = cnt_val;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (is_hessian_first) {
// this is the hessians
// thread 8 - 15 read counters stored by thread 0 - 7
// thread 24- 31 read counters stored by thread 8 - 15
// thread 40- 47 read counters stored by thread 16- 23, etc
stat_val = const_hessian *
cnt_hist[((ltid - DWORD_FEATURES) >> (LOG2_DWORD_FEATURES + 1)) * DWORD_FEATURES + (ltid & DWORD_FEATURES_MASK)];
}
else {
// this is the gradients
// thread 0 - 7 read gradients stored by thread 8 - 15
// thread 16- 23 read gradients stored by thread 24- 31
// thread 32- 39 read gradients stored by thread 40- 47, etc
stat_val += gh_hist[ltid + DWORD_FEATURES];
}
barrier(CLK_LOCAL_MEM_FENCE);
#endif
// write to output
// write gradients and hessians histogram for all 4 features
// output data in linear order for further reduction
// output size = 4 (features) * 3 (counters) * 64 (bins) * sizeof(float)
/* memory layout of output:
g_f0_bin0 g_f1_bin0 g_f2_bin0 g_f3_bin0 g_f4_bin0 g_f5_bin0 g_f6_bin0 g_f7_bin0
h_f0_bin0 h_f1_bin0 h_f2_bin0 h_f3_bin0 h_f4_bin0 h_f5_bin0 h_f6_bin0 h_f7_bin0
g_f0_bin1 g_f1_bin1 g_f2_bin1 g_f3_bin1 g_f4_bin1 g_f5_bin1 g_f6_bin1 g_f7_bin1
h_f0_bin1 h_f1_bin1 h_f2_bin1 h_f3_bin1 h_f4_bin1 h_f5_bin1 h_f6_bin1 h_f7_bin1
...
...
g_f0_bin16 g_f1_bin16 g_f2_bin16 g_f3_bin16 g_f4_bin16 g_f5_bin16 g_f6_bin16 g_f7_bin16
h_f0_bin16 h_f1_bin16 h_f2_bin16 h_f3_bin16 h_f4_bin16 h_f5_bin16 h_f6_bin16 h_f7_bin16
c_f0_bin0 c_f1_bin0 c_f2_bin0 c_f3_bin0 c_f4_bin0 c_f5_bin0 c_f6_bin0 c_f7_bin0
c_f0_bin1 c_f1_bin1 c_f2_bin1 c_f3_bin1 c_f4_bin1 c_f5_bin1 c_f6_bin1 c_f7_bin1
...
c_f0_bin16 c_f1_bin16 c_f2_bin16 c_f3_bin16 c_f4_bin16 c_f5_bin16 c_f6_bin16 c_f7_bin16
*/
// if there is only one workgroup processing this feature4, don't even need to write
uint feature4_id = (group_id >> POWER_FEATURE_WORKGROUPS);
#if POWER_FEATURE_WORKGROUPS != 0
__global acc_type * restrict output = (__global acc_type * restrict)output_buf + group_id * DWORD_FEATURES * 3 * NUM_BINS;
// if g_val and h_val are double, they are converted to float here
// write gradients and hessians for 8 features
output[0 * DWORD_FEATURES * NUM_BINS + ltid] = stat_val;
// write counts for 8 features
if (ltid < LOCAL_SIZE_0 / 2) {
output[2 * DWORD_FEATURES * NUM_BINS + ltid] = as_acc_type((acc_int_type)cnt_val);
}
barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
mem_fence(CLK_GLOBAL_MEM_FENCE);
// To avoid the cost of an extra reducting kernel, we have to deal with some
// gray area in OpenCL. We want the last work group that process this feature to
// make the final reduction, and other threads will just quit.
// This requires that the results written by other workgroups available to the
// last workgroup (memory consistency)
#if NVIDIA == 1
// this is equavalent to CUDA __threadfence();
// ensure the writes above goes to main memory and other workgroups can see it
asm volatile("{\n\tmembar.gl;\n\t}\n\t" :::"memory");
#else
// FIXME: how to do the above on AMD GPUs??
// GCN ISA says that the all writes will bypass L1 cache (write through),
// however when the last thread is reading sub-histogram data we have to
// make sure that no part of data is modified in local L1 cache of other workgroups.
// Otherwise reading can be a problem (atomic operations to get consistency).
// But in our case, the sub-histogram of this workgroup cannot be in the cache
// of another workgroup, so the following trick will work just fine.
#endif
// Now, we want one workgroup to do the final reduction.
// Other workgroups processing the same feature quit.
// The is done by using an global atomic counter.
// On AMD GPUs ideally this should be done in GDS,
// but currently there is no easy way to access it via OpenCL.
__local uint * counter_val = cnt_hist;
if (ltid == 0) {
// all workgroups processing the same feature add this counter
*counter_val = atom_inc(sync_counters + feature4_id);
}
// make sure everyone in this workgroup is here
barrier(CLK_LOCAL_MEM_FENCE);
// everyone in this wrokgroup: if we are the last workgroup, then do reduction!
if (*counter_val == (1 << POWER_FEATURE_WORKGROUPS) - 1) {
if (ltid == 0) {
// printf("workgroup %d start reduction!\n", group_id);
// printf("feature_data[0] = %d %d %d %d", feature_data[0].s0, feature_data[0].s1, feature_data[0].s2, feature_data[0].s3);
// clear the sync counter for using it next time
sync_counters[feature4_id] = 0;
}
#else
// only 1 work group, no need to increase counter
// the reduction will become a simple copy
if (1) {
barrier(CLK_LOCAL_MEM_FENCE);
#endif
// locate our feature4's block in output memory
uint output_offset = (feature4_id << POWER_FEATURE_WORKGROUPS);
__global acc_type const * restrict feature4_subhists =
(__global acc_type *)output_buf + output_offset * DWORD_FEATURES * 3 * NUM_BINS;
// skip reading the data already in local memory
uint skip_id = group_id ^ output_offset;
// locate output histogram location for this feature4
__global acc_type* restrict hist_buf = hist_buf_base + feature4_id * DWORD_FEATURES * 3 * NUM_BINS;
within_kernel_reduction16x8(feature_mask, feature4_subhists, skip_id, stat_val, cnt_val,
1 << POWER_FEATURE_WORKGROUPS, hist_buf, (__local acc_type *)shared_array);
}
}
// The following line ends the string literal, adds an extra #endif at the end
// the +9 skips extra characters ")", newline, "#endif" and newline at the beginning
// )"" "\n#endif" + 9
#endif
// this file can either be read and passed to an OpenCL compiler directly,
// or included in a C++11 source file as a string literal
#ifndef __OPENCL_VERSION__
// If we are including this file in C++,
// the entire source file following (except the last #endif) will become
// a raw string literal. The extra ")" is just for mathcing parentheses
// to make the editor happy. The extra ")" and extra endif will be skipped.
// DO NOT add anything between here and the next #ifdef, otherwise you need
// to modify the skip count at the end of this file.
R""()
#endif
#ifndef _HISTOGRAM_256_KERNEL_
#define _HISTOGRAM_256_KERNEL_
#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
// use double precision or not
#ifndef USE_DP_FLOAT
#define USE_DP_FLOAT 0
#endif
// ignore hessian, and use the local memory for hessian as an additional bank for gradient
#ifndef CONST_HESSIAN
#define CONST_HESSIAN 0
#endif
#define LOCAL_SIZE_0 256
#define NUM_BINS 256
#if USE_DP_FLOAT == 1
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
typedef double acc_type;
typedef ulong acc_int_type;
#define as_acc_type as_double
#define as_acc_int_type as_ulong
#else
typedef float acc_type;
typedef uint acc_int_type;
#define as_acc_type as_float
#define as_acc_int_type as_uint
#endif
#define LOCAL_MEM_SIZE (4 * (sizeof(uint) + 2 * sizeof(acc_type)) * NUM_BINS)
// unroll the atomic operation for a few times. Takes more code space,
// but compiler can generate better code for faster atomics.
#define UNROLL_ATOMIC 1
// Options passed by compiler at run time:
// IGNORE_INDICES will be set when the kernel does not
// #define IGNORE_INDICES
// #define POWER_FEATURE_WORKGROUPS 10
// detect Nvidia platforms
#ifdef cl_nv_pragma_unroll
#define NVIDIA 1
#endif
// use all features and do not use feature mask
#ifndef ENABLE_ALL_FEATURES
#define ENABLE_ALL_FEATURES 1
#endif
// use binary patching for AMD GCN 1.2 or newer
#ifndef AMD_USE_DS_ADD_F32
#define AMD_USE_DS_ADD_F32 0
#endif
typedef uint data_size_t;
typedef float score_t;
#define ATOMIC_FADD_SUB1 { \
expected.f_val = current.f_val; \
next.f_val = expected.f_val + val; \
current.u_val = atom_cmpxchg((volatile __local acc_int_type *)addr, expected.u_val, next.u_val); \
if (current.u_val == expected.u_val) \
goto end; \
}
#define ATOMIC_FADD_SUB2 ATOMIC_FADD_SUB1 \
ATOMIC_FADD_SUB1
#define ATOMIC_FADD_SUB4 ATOMIC_FADD_SUB2 \
ATOMIC_FADD_SUB2
#define ATOMIC_FADD_SUB8 ATOMIC_FADD_SUB4 \
ATOMIC_FADD_SUB4
#define ATOMIC_FADD_SUB16 ATOMIC_FADD_SUB8 \
ATOMIC_FADD_SUB8
#define ATOMIC_FADD_SUB32 ATOMIC_FADD_SUB16\
ATOMIC_FADD_SUB16
#define ATOMIC_FADD_SUB64 ATOMIC_FADD_SUB32\
ATOMIC_FADD_SUB32
// atomic add for float number in local memory
inline void atomic_local_add_f(__local acc_type *addr, const float val)
{
union{
acc_int_type u_val;
acc_type f_val;
} next, expected, current;
#if (NVIDIA == 1 && USE_DP_FLOAT == 0)
float res = 0;
asm volatile ("atom.shared.add.f32 %0, [%1], %2;" : "=f"(res) : "l"(addr), "f"(val));
#elif (AMD_USE_DS_ADD_F32 == 1 && USE_DP_FLAT == 0)
// this instruction (DS_AND_U32) will be patched into a DS_ADD_F32
// we need to hack here because DS_ADD_F32 is not exposed via OpenCL
atom_and((__local acc_int_type *)addr, as_acc_int_type(val));
#else
current.f_val = *addr;
#if UNROLL_ATOMIC == 1
// provide a fast path
// then do the complete loop
// this should work on all devices
ATOMIC_FADD_SUB8
ATOMIC_FADD_SUB4
ATOMIC_FADD_SUB2
#endif
do {
expected.f_val = current.f_val;
next.f_val = expected.f_val + val;
current.u_val = atom_cmpxchg((volatile __local acc_int_type *)addr, expected.u_val, next.u_val);
} while (current.u_val != expected.u_val);
end:
;
#endif
}
// this function will be called by histogram256
// we have one sub-histogram of one feature in local memory, and need to read others
void within_kernel_reduction256x4(uchar4 feature_mask,
__global const acc_type* restrict feature4_sub_hist,
const uint skip_id,
const uint old_val_f0_cont_bin0,
const ushort num_sub_hist,
__global acc_type* restrict output_buf,
__local acc_type* restrict local_hist) {
const ushort ltid = get_local_id(0);
const ushort lsize = LOCAL_SIZE_0;
// initialize register counters from our local memory
// TODO: try to avoid bank conflict here
acc_type f0_grad_bin = local_hist[ltid * 8];
acc_type f1_grad_bin = local_hist[ltid * 8 + 1];
acc_type f2_grad_bin = local_hist[ltid * 8 + 2];
acc_type f3_grad_bin = local_hist[ltid * 8 + 3];
acc_type f0_hess_bin = local_hist[ltid * 8 + 4];
acc_type f1_hess_bin = local_hist[ltid * 8 + 5];
acc_type f2_hess_bin = local_hist[ltid * 8 + 6];
acc_type f3_hess_bin = local_hist[ltid * 8 + 7];
__local uint* restrict local_cnt = (__local uint *)(local_hist + 4 * 2 * NUM_BINS);
#if POWER_FEATURE_WORKGROUPS != 0
uint f0_cont_bin = ltid ? local_cnt[ltid * 4] : old_val_f0_cont_bin0;
#else
uint f0_cont_bin = local_cnt[ltid * 4];
#endif
uint f1_cont_bin = local_cnt[ltid * 4 + 1];
uint f2_cont_bin = local_cnt[ltid * 4 + 2];
uint f3_cont_bin = local_cnt[ltid * 4 + 3];
ushort i;
// printf("%d-pre(skip %d): %f %f %f %f %f %f %f %f %d %d %d %d", ltid, skip_id, f0_grad_bin, f1_grad_bin, f2_grad_bin, f3_grad_bin, f0_hess_bin, f1_hess_bin, f2_hess_bin, f3_hess_bin, f0_cont_bin, f1_cont_bin, f2_cont_bin, f3_cont_bin);
#if POWER_FEATURE_WORKGROUPS != 0
// add all sub-histograms for 4 features
__global const acc_type* restrict p = feature4_sub_hist + ltid;
for (i = 0; i < skip_id; ++i) {
if (feature_mask.s3) {
f0_grad_bin += *p; p += NUM_BINS;
f0_hess_bin += *p; p += NUM_BINS;
f0_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
if (feature_mask.s2) {
f1_grad_bin += *p; p += NUM_BINS;
f1_hess_bin += *p; p += NUM_BINS;
f1_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
if (feature_mask.s1) {
f2_grad_bin += *p; p += NUM_BINS;
f2_hess_bin += *p; p += NUM_BINS;
f2_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
if (feature_mask.s0) {
f3_grad_bin += *p; p += NUM_BINS;
f3_hess_bin += *p; p += NUM_BINS;
f3_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
}
// skip the counters we already have
p += 3 * 4 * NUM_BINS;
for (i = i + 1; i < num_sub_hist; ++i) {
if (feature_mask.s3) {
f0_grad_bin += *p; p += NUM_BINS;
f0_hess_bin += *p; p += NUM_BINS;
f0_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
if (feature_mask.s2) {
f1_grad_bin += *p; p += NUM_BINS;
f1_hess_bin += *p; p += NUM_BINS;
f1_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
if (feature_mask.s1) {
f2_grad_bin += *p; p += NUM_BINS;
f2_hess_bin += *p; p += NUM_BINS;
f2_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
if (feature_mask.s0) {
f3_grad_bin += *p; p += NUM_BINS;
f3_hess_bin += *p; p += NUM_BINS;
f3_cont_bin += as_acc_int_type(*p); p += NUM_BINS;
}
else {
p += 3 * NUM_BINS;
}
}
// printf("%d-aft: %f %f %f %f %f %f %f %f %d %d %d %d", ltid, f0_grad_bin, f1_grad_bin, f2_grad_bin, f3_grad_bin, f0_hess_bin, f1_hess_bin, f2_hess_bin, f3_hess_bin, f0_cont_bin, f1_cont_bin, f2_cont_bin, f3_cont_bin);
#endif
// now overwrite the local_hist for final reduction and output
barrier(CLK_LOCAL_MEM_FENCE);
#if USE_DP_FLOAT == 0
// reverse the f3...f0 order to match the real order
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 0] = f3_grad_bin;
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 1] = f3_hess_bin;
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f3_cont_bin);
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 0] = f2_grad_bin;
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 1] = f2_hess_bin;
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f2_cont_bin);
local_hist[2 * 3 * NUM_BINS + ltid * 3 + 0] = f1_grad_bin;
local_hist[2 * 3 * NUM_BINS + ltid * 3 + 1] = f1_hess_bin;
local_hist[2 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f1_cont_bin);
local_hist[3 * 3 * NUM_BINS + ltid * 3 + 0] = f0_grad_bin;
local_hist[3 * 3 * NUM_BINS + ltid * 3 + 1] = f0_hess_bin;
local_hist[3 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f0_cont_bin);
barrier(CLK_LOCAL_MEM_FENCE);
/*
for (ushort i = ltid; i < 4 * 3 * NUM_BINS; i += lsize) {
output_buf[i] = local_hist[i];
}
*/
i = ltid;
if (feature_mask.s0) {
output_buf[i] = local_hist[i];
output_buf[i + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s1) {
output_buf[i] = local_hist[i];
output_buf[i + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s2) {
output_buf[i] = local_hist[i];
output_buf[i + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s3 && i < 4 * 3 * NUM_BINS) {
output_buf[i] = local_hist[i];
output_buf[i + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
#else
// when double precision is used, we need to write twice, because local memory size is not enough
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 0] = f3_grad_bin;
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 1] = f3_hess_bin;
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f3_cont_bin);
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 0] = f2_grad_bin;
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 1] = f2_hess_bin;
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f2_cont_bin);
barrier(CLK_LOCAL_MEM_FENCE);
/*
for (ushort i = ltid; i < 2 * 3 * NUM_BINS; i += lsize) {
output_buf[i] = local_hist[i];
}
*/
i = ltid;
if (feature_mask.s0) {
output_buf[i] = local_hist[i];
output_buf[i + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s1) {
output_buf[i] = local_hist[i];
output_buf[i + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
barrier(CLK_LOCAL_MEM_FENCE);
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 0] = f1_grad_bin;
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 1] = f1_hess_bin;
local_hist[0 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f1_cont_bin);
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 0] = f0_grad_bin;
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 1] = f0_hess_bin;
local_hist[1 * 3 * NUM_BINS + ltid * 3 + 2] = as_acc_type((acc_int_type)f0_cont_bin);
barrier(CLK_LOCAL_MEM_FENCE);
/*
for (ushort i = ltid; i < 2 * 3 * NUM_BINS; i += lsize) {
output_buf[i + 2 * 3 * NUM_BINS] = local_hist[i];
}
*/
i = ltid;
if (feature_mask.s2) {
output_buf[i + 2 * 3 * NUM_BINS] = local_hist[i];
output_buf[i + 2 * 3 * NUM_BINS + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * 3 * NUM_BINS + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s3) {
output_buf[i + 2 * 3 * NUM_BINS] = local_hist[i];
output_buf[i + 2 * 3 * NUM_BINS + NUM_BINS] = local_hist[i + NUM_BINS];
output_buf[i + 2 * 3 * NUM_BINS + 2 * NUM_BINS] = local_hist[i + 2 * NUM_BINS];
}
#endif
}
#define printf
__attribute__((reqd_work_group_size(LOCAL_SIZE_0, 1, 1)))
#if USE_CONSTANT_BUF == 1
__kernel void histogram256(__global const uchar4* restrict feature_data_base,
__constant const uchar4* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
const data_size_t num_data,
__constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
#if CONST_HESSIAN == 0
__constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
#else
const score_t const_hessian,
#endif
__global char* restrict output_buf,
__global volatile int * sync_counters,
__global acc_type* restrict hist_buf_base) {
#else
__kernel void histogram256(__global const uchar4* feature_data_base,
__constant const uchar4* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__global const data_size_t* data_indices,
const data_size_t num_data,
__global const score_t* ordered_gradients,
#if CONST_HESSIAN == 0
__global const score_t* ordered_hessians,
#else
const score_t const_hessian,
#endif
__global char* restrict output_buf,
__global volatile int * sync_counters,
__global acc_type* restrict hist_buf_base) {
#endif
// allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
// otherwise a "Misaligned Address" exception may occur
__local float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
const uint gtid = get_global_id(0);
const uint gsize = get_global_size(0);
const ushort ltid = get_local_id(0);
const ushort lsize = LOCAL_SIZE_0; // get_local_size(0);
const ushort group_id = get_group_id(0);
// local memory per workgroup is 12 KB
// clear local memory
__local uint * ptr = (__local uint *) shared_array;
for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(uint); i += lsize) {
ptr[i] = 0;
}
barrier(CLK_LOCAL_MEM_FENCE);
// gradient/hessian histograms
// assume this starts at 32 * 4 = 128-byte boundary
// total size: 2 * 4 * 256 * size_of(float) = 8 KB
// organization: each feature/grad/hessian is at a different bank,
// as indepedent of the feature value as possible
__local acc_type * gh_hist = (__local acc_type *)shared_array;
// counter histogram
// total size: 4 * 256 * size_of(uint) = 4 KB
__local uint * cnt_hist = (__local uint *)(gh_hist + 2 * 4 * NUM_BINS);
// thread 0, 1, 2, 3 compute histograms for gradients first
// thread 4, 5, 6, 7 compute histograms for hessians first
// etc.
uchar is_hessian_first = (ltid >> 2) & 1;
ushort group_feature = group_id >> POWER_FEATURE_WORKGROUPS;
// each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
// feature_size is the number of examples per feature
__global const uchar4* feature_data = feature_data_base + group_feature * feature_size;
// size of threads that process this feature4
const uint subglobal_size = lsize * (1 << POWER_FEATURE_WORKGROUPS);
// equavalent thread ID in this subgroup for this feature4
const uint subglobal_tid = gtid - group_feature * subglobal_size;
// extract feature mask, when a byte is set to 0, that feature is disabled
#if ENABLE_ALL_FEATURES == 1
// hopefully the compiler will propogate the constants and eliminate all branches
uchar4 feature_mask = (uchar4)(0xff, 0xff, 0xff, 0xff);
#else
uchar4 feature_mask = feature_masks[group_feature];
#endif
// exit if all features are masked
if (!as_uint(feature_mask)) {
return;
}
// STAGE 1: read feature data, and gradient and hessian
// first half of the threads read feature data from global memory
// 4 features stored in a tuple MSB...(0, 1, 2, 3)...LSB
// We will prefetch data into the "next" variable at the beginning of each iteration
uchar4 feature4;
uchar4 feature4_next;
uchar4 feature4_prev;
// offset used to rotate feature4 vector
ushort offset = (ltid & 0x3);
// store gradient and hessian
float stat1, stat2;
float stat1_next, stat2_next;
ushort bin, addr, addr2;
data_size_t ind;
data_size_t ind_next;
stat1 = ordered_gradients[subglobal_tid];
#if CONST_HESSIAN == 0
stat2 = ordered_hessians[subglobal_tid];
#endif
#ifdef IGNORE_INDICES
ind = subglobal_tid;
#else
ind = data_indices[subglobal_tid];
#endif
feature4 = feature_data[ind];
feature4_prev = feature4;
feature4_prev = as_uchar4(rotate(as_uint(feature4_prev), (uint)offset*8));
#if ENABLE_ALL_FEATURES == 0
// rotate feature_mask to match the feature order of each thread
feature_mask = as_uchar4(rotate(as_uint(feature_mask), (uint)offset*8));
#endif
acc_type s3_stat1 = 0.0f, s3_stat2 = 0.0f;
acc_type s2_stat1 = 0.0f, s2_stat2 = 0.0f;
acc_type s1_stat1 = 0.0f, s1_stat2 = 0.0f;
acc_type s0_stat1 = 0.0f, s0_stat2 = 0.0f;
// there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
for (uint i = subglobal_tid; i < num_data; i += subglobal_size) {
// prefetch the next iteration variables
// we don't need bondary check because we have made the buffer larger
stat1_next = ordered_gradients[i + subglobal_size];
#if CONST_HESSIAN == 0
stat2_next = ordered_hessians[i + subglobal_size];
#endif
#ifdef IGNORE_INDICES
// we need to check to bounds here
ind_next = i + subglobal_size < num_data ? i + subglobal_size : i;
// start load next feature as early as possible
feature4_next = feature_data[ind_next];
#else
ind_next = data_indices[i + subglobal_size];
#endif
#if CONST_HESSIAN == 0
// swap gradient and hessian for threads 4, 5, 6, 7
float tmp = stat1;
stat1 = is_hessian_first ? stat2 : stat1;
stat2 = is_hessian_first ? tmp : stat2;
// stat1 = select(stat1, stat2, is_hessian_first);
// stat2 = select(stat2, tmp, is_hessian_first);
#endif
// STAGE 2: accumulate gradient and hessian
offset = (ltid & 0x3);
feature4 = as_uchar4(rotate(as_uint(feature4), (uint)offset*8));
bin = feature4.s3;
if ((bin != feature4_prev.s3) && feature_mask.s3) {
// printf("%3d (%4d): writing s3 %d %d offset %d", ltid, i, bin, feature4_prev.s3, offset);
bin = feature4_prev.s3;
feature4_prev.s3 = feature4.s3;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s3_stat1);
// thread 0, 1, 2, 3 now process feature 0, 1, 2, 3's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 0, 1, 2, 3's hessians for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s3_stat2);
#endif
// thread 0, 1, 2, 3 now process feature 0, 1, 2, 3's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 0, 1, 2, 3's gradients for example 4, 5, 6, 7
s3_stat1 = stat1;
s3_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s3 %d", ltid, i, bin);
s3_stat1 += stat1;
s3_stat2 += stat2;
}
bin = feature4.s2;
offset = (offset + 1) & 0x3;
if ((bin != feature4_prev.s2) && feature_mask.s2) {
// printf("%3d (%4d): writing s2 %d %d feature %d", ltid, i, bin, feature4_prev.s2, offset);
bin = feature4_prev.s2;
feature4_prev.s2 = feature4.s2;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s2_stat1);
// thread 0, 1, 2, 3 now process feature 1, 2, 3, 0's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 1, 2, 3, 0's hessians for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s2_stat2);
#endif
// thread 0, 1, 2, 3 now process feature 1, 2, 3, 0's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 1, 2, 3, 0's gradients for example 4, 5, 6, 7
s2_stat1 = stat1;
s2_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s2 %d", ltid, i, bin);
s2_stat1 += stat1;
s2_stat2 += stat2;
}
// prefetch the next iteration variables
// we don't need bondary check because if it is out of boundary, ind_next = 0
#ifndef IGNORE_INDICES
feature4_next = feature_data[ind_next];
#endif
bin = feature4.s1;
offset = (offset + 1) & 0x3;
if ((bin != feature4_prev.s1) && feature_mask.s1) {
// printf("%3d (%4d): writing s1 %d %d feature %d", ltid, i, bin, feature4_prev.s1, offset);
bin = feature4_prev.s1;
feature4_prev.s1 = feature4.s1;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s1_stat1);
// thread 0, 1, 2, 3 now process feature 2, 3, 0, 1's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 2, 3, 0, 1's hessians for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s1_stat2);
#endif
// thread 0, 1, 2, 3 now process feature 2, 3, 0, 1's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 2, 3, 0, 1's gradients for example 4, 5, 6, 7
s1_stat1 = stat1;
s1_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s1 %d", ltid, i, bin);
s1_stat1 += stat1;
s1_stat2 += stat2;
}
bin = feature4.s0;
offset = (offset + 1) & 0x3;
if ((bin != feature4_prev.s0) && feature_mask.s0) {
// printf("%3d (%4d): writing s0 %d %d feature %d", ltid, i, bin, feature4_prev.s0, offset);
bin = feature4_prev.s0;
feature4_prev.s0 = feature4.s0;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s0_stat1);
// thread 0, 1, 2, 3 now process feature 3, 0, 1, 2's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 3, 0, 1, 2's hessians for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s0_stat2);
#endif
// thread 0, 1, 2, 3 now process feature 3, 0, 1, 2's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 3, 0, 1, 2's gradients for example 4, 5, 6, 7
s0_stat1 = stat1;
s0_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s0 %d", ltid, i, bin);
s0_stat1 += stat1;
s0_stat2 += stat2;
}
// STAGE 3: accumulate counter
// there are 4 counters for 4 features
// thread 0, 1, 2, 3 now process feature 0, 1, 2, 3's counts for example 0, 1, 2, 3
offset = (ltid & 0x3);
if (feature_mask.s3) {
bin = feature4.s3;
addr = bin * 4 + offset;
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3 now process feature 1, 2, 3, 0's counts for example 0, 1, 2, 3
offset = (offset + 1) & 0x3;
if (feature_mask.s2) {
bin = feature4.s2;
addr = bin * 4 + offset;
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3 now process feature 2, 3, 0, 1's counts for example 0, 1, 2, 3
offset = (offset + 1) & 0x3;
if (feature_mask.s1) {
bin = feature4.s1;
addr = bin * 4 + offset;
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3 now process feature 3, 0, 1, 2's counts for example 0, 1, 2, 3
offset = (offset + 1) & 0x3;
if (feature_mask.s0) {
bin = feature4.s0;
addr = bin * 4 + offset;
atom_inc(cnt_hist + addr);
}
stat1 = stat1_next;
stat2 = stat2_next;
feature4 = feature4_next;
}
bin = feature4_prev.s3;
offset = (ltid & 0x3);
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s3_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s3_stat2);
#endif
bin = feature4_prev.s2;
offset = (offset + 1) & 0x3;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s2_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s2_stat2);
#endif
bin = feature4_prev.s1;
offset = (offset + 1) & 0x3;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s1_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s1_stat2);
#endif
bin = feature4_prev.s0;
offset = (offset + 1) & 0x3;
addr = bin * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s0_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s0_stat2);
#endif
barrier(CLK_LOCAL_MEM_FENCE);
#if ENABLE_ALL_FEATURES == 0
// restore feature_mask
feature_mask = feature_masks[group_feature];
#endif
#if CONST_HESSIAN == 1
barrier(CLK_LOCAL_MEM_FENCE);
// make a final reduction
offset = ltid & 0x3; // helps avoid LDS bank conflicts
gh_hist[ltid * 8 + offset] += gh_hist[ltid * 8 + offset + 4];
gh_hist[ltid * 8 + offset + 4] = const_hessian * cnt_hist[ltid * 4 + offset];
offset = (offset + 1) & 0x3;
gh_hist[ltid * 8 + offset] += gh_hist[ltid * 8 + offset + 4];
gh_hist[ltid * 8 + offset + 4] = const_hessian * cnt_hist[ltid * 4 + offset];
offset = (offset + 1) & 0x3;
gh_hist[ltid * 8 + offset] += gh_hist[ltid * 8 + offset + 4];
gh_hist[ltid * 8 + offset + 4] = const_hessian * cnt_hist[ltid * 4 + offset];
offset = (offset + 1) & 0x3;
gh_hist[ltid * 8 + offset] += gh_hist[ltid * 8 + offset + 4];
gh_hist[ltid * 8 + offset + 4] = const_hessian * cnt_hist[ltid * 4 + offset];
barrier(CLK_LOCAL_MEM_FENCE);
#endif
// write to output
// write gradients and hessians histogram for all 4 features
/* memory layout in gh_hist (total 2 * 4 * 256 * sizeof(float) = 8 KB):
-----------------------------------------------------------------------------------------------
g_f0_bin0 g_f1_bin0 g_f2_bin0 g_f3_bin0 h_f0_bin0 h_f1_bin0 h_f2_bin0 h_f3_bin0
g_f0_bin1 g_f1_bin1 g_f2_bin1 g_f3_bin1 h_f0_bin1 h_f1_bin1 h_f2_bin1 h_f3_bin1
...
g_f0_bin255 g_f1_bin255 g_f2_bin255 g_f3_bin255 h_f0_bin255 h_f1_bin255 h_f2_bin255 h_f3_bin255
-----------------------------------------------------------------------------------------------
*/
/* memory layout in cnt_hist (total 4 * 256 * sizeof(uint) = 4 KB):
-----------------------------------------------
c_f0_bin0 c_f1_bin0 c_f2_bin0 c_f3_bin0
c_f0_bin1 c_f1_bin1 c_f2_bin1 c_f3_bin1
...
c_f0_bin255 c_f1_bin255 c_f2_bin255 c_f3_bin255
-----------------------------------------------
*/
// output data in linear order for further reduction
// output size = 4 (features) * 3 (counters) * 256 (bins) * sizeof(float)
/* memory layout of output:
--------------------------------------------
g_f0_bin0 g_f0_bin1 ... g_f0_bin255 \
h_f0_bin0 h_f0_bin1 ... h_f0_bin255 |
c_f0_bin0 c_f0_bin1 ... c_f0_bin255 |
g_f1_bin0 g_f1_bin1 ... g_f1_bin255 |
h_f1_bin0 h_f1_bin1 ... h_f1_bin255 |
c_f1_bin0 c_f1_bin1 ... c_f1_bin255 |--- 1 sub-histogram block
g_f2_bin0 g_f2_bin1 ... g_f2_bin255 |
h_f2_bin0 h_f2_bin1 ... h_f2_bin255 |
c_f2_bin0 c_f2_bin1 ... c_f2_bin255 |
g_f3_bin0 g_f3_bin1 ... g_f3_bin255 |
h_f3_bin0 h_f3_bin1 ... h_f3_bin255 |
c_f3_bin0 c_f3_bin1 ... c_f3_bin255 /
--------------------------------------------
*/
uint feature4_id = (group_id >> POWER_FEATURE_WORKGROUPS);
// if there is only one workgroup processing this feature4, don't even need to write
#if POWER_FEATURE_WORKGROUPS != 0
__global acc_type * restrict output = (__global acc_type * restrict)output_buf + group_id * 4 * 3 * NUM_BINS;
// write gradients and hessians
__global acc_type * restrict ptr_f = output;
for (ushort j = 0; j < 4; ++j) {
for (ushort i = ltid; i < 2 * NUM_BINS; i += lsize) {
// even threads read gradients, odd threads read hessians
// FIXME: 2-way bank conflict
acc_type value = gh_hist[i * 4 + j];
ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value;
}
ptr_f += 3 * NUM_BINS;
}
// write counts
__global acc_int_type * restrict ptr_i = (__global acc_int_type * restrict)(output + 2 * NUM_BINS);
for (ushort j = 0; j < 4; ++j) {
for (ushort i = ltid; i < NUM_BINS; i += lsize) {
// FIXME: 2-way bank conflict
uint value = cnt_hist[i * 4 + j];
ptr_i[i] = value;
}
ptr_i += 3 * NUM_BINS;
}
barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
mem_fence(CLK_GLOBAL_MEM_FENCE);
// To avoid the cost of an extra reducting kernel, we have to deal with some
// gray area in OpenCL. We want the last work group that process this feature to
// make the final reduction, and other threads will just quit.
// This requires that the results written by other workgroups available to the
// last workgroup (memory consistency)
#if NVIDIA == 1
// this is equavalent to CUDA __threadfence();
// ensure the writes above goes to main memory and other workgroups can see it
asm volatile("{\n\tmembar.gl;\n\t}\n\t" :::"memory");
#else
// FIXME: how to do the above on AMD GPUs??
// GCN ISA says that the all writes will bypass L1 cache (write through),
// however when the last thread is reading sub-histogram data we have to
// make sure that no part of data is modified in local L1 cache of other workgroups.
// Otherwise reading can be a problem (atomic operations to get consistency).
// But in our case, the sub-histogram of this workgroup cannot be in the cache
// of another workgroup, so the following trick will work just fine.
#endif
// Now, we want one workgroup to do the final reduction.
// Other workgroups processing the same feature quit.
// The is done by using an global atomic counter.
// On AMD GPUs ideally this should be done in GDS,
// but currently there is no easy way to access it via OpenCL.
__local uint * counter_val = cnt_hist;
// backup the old value
uint old_val = *counter_val;
if (ltid == 0) {
// all workgroups processing the same feature add this counter
*counter_val = atom_inc(sync_counters + feature4_id);
}
// make sure everyone in this workgroup is here
barrier(CLK_LOCAL_MEM_FENCE);
// everyone in this wrokgroup: if we are the last workgroup, then do reduction!
if (*counter_val == (1 << POWER_FEATURE_WORKGROUPS) - 1) {
if (ltid == 0) {
// printf("workgroup %d: %g %g %g %g %g %g %g %g\n", group_id, gh_hist[0], gh_hist[1], gh_hist[2], gh_hist[3], gh_hist[4], gh_hist[5], gh_hist[6], gh_hist[7]);
// printf("feature_data[0] = %d %d %d %d", feature_data[0].s0, feature_data[0].s1, feature_data[0].s2, feature_data[0].s3);
// clear the sync counter for using it next time
sync_counters[feature4_id] = 0;
}
#else
// only 1 work group, no need to increase counter
// the reduction will become a simple copy
if (1) {
uint old_val; // dummy
#endif
// locate our feature4's block in output memory
uint output_offset = (feature4_id << POWER_FEATURE_WORKGROUPS);
__global acc_type const * restrict feature4_subhists =
(__global acc_type *)output_buf + output_offset * 4 * 3 * NUM_BINS;
// skip reading the data already in local memory
uint skip_id = group_id ^ output_offset;
// locate output histogram location for this feature4
__global acc_type* restrict hist_buf = hist_buf_base + feature4_id * 4 * 3 * NUM_BINS;
within_kernel_reduction256x4(feature_mask, feature4_subhists, skip_id, old_val, 1 << POWER_FEATURE_WORKGROUPS,
hist_buf, (__local acc_type *)shared_array);
// if (ltid == 0)
// printf("workgroup %d reduction done, %g %g %g %g %g %g %g %g\n", group_id, hist_buf[0], hist_buf[3*NUM_BINS], hist_buf[2*3*NUM_BINS], hist_buf[3*3*NUM_BINS], hist_buf[1], hist_buf[3*NUM_BINS+1], hist_buf[2*3*NUM_BINS+1], hist_buf[3*3*NUM_BINS+1]);
}
}
// The following line ends the string literal, adds an extra #endif at the end
// the +9 skips extra characters ")", newline, "#endif" and newline at the beginning
// )"" "\n#endif" + 9
#endif
// this file can either be read and passed to an OpenCL compiler directly,
// or included in a C++11 source file as a string literal
#ifndef __OPENCL_VERSION__
// If we are including this file in C++,
// the entire source file following (except the last #endif) will become
// a raw string literal. The extra ")" is just for mathcing parentheses
// to make the editor happy. The extra ")" and extra endif will be skipped.
// DO NOT add anything between here and the next #ifdef, otherwise you need
// to modify the skip count at the end of this file.
R""()
#endif
#ifndef _HISTOGRAM_64_KERNEL_
#define _HISTOGRAM_64_KERNEL_
#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
// Configurable options:
// NUM_BANKS should be a power of 2
#ifndef NUM_BANKS
#define NUM_BANKS 4
#endif
// how many bits in thread ID represent the bank = log2(NUM_BANKS)
#ifndef BANK_BITS
#define BANK_BITS 2
#endif
// use double precision or not
#ifndef USE_DP_FLOAT
#define USE_DP_FLOAT 0
#endif
// ignore hessian, and use the local memory for hessian as an additional bank for gradient
#ifndef CONST_HESSIAN
#define CONST_HESSIAN 0
#endif
#define LOCAL_SIZE_0 256
#define NUM_BINS 64
// if USE_DP_FLOAT is set to 1, we will use double precision for the accumulator
#if USE_DP_FLOAT == 1
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
typedef double acc_type;
typedef ulong acc_int_type;
#define as_acc_type as_double
#define as_acc_int_type as_ulong
#else
typedef float acc_type;
typedef uint acc_int_type;
#define as_acc_type as_float
#define as_acc_int_type as_uint
#endif
// mask for getting the bank ID
#define BANK_MASK (NUM_BANKS - 1)
// 4 features, each has a gradient and a hessian
#define HG_BIN_MULT (NUM_BANKS * 4 * 2)
// 4 features, each has a counter
#define CNT_BIN_MULT (NUM_BANKS * 4)
// local memory size in bytes
#define LOCAL_MEM_SIZE (4 * (sizeof(uint) + 2 * sizeof(acc_type)) * NUM_BINS * NUM_BANKS)
// unroll the atomic operation for a few times. Takes more code space,
// but compiler can generate better code for faster atomics.
#define UNROLL_ATOMIC 1
// Options passed by compiler at run time:
// IGNORE_INDICES will be set when the kernel does not
// #define IGNORE_INDICES
// #define POWER_FEATURE_WORKGROUPS 10
// use all features and do not use feature mask
#ifndef ENABLE_ALL_FEATURES
#define ENABLE_ALL_FEATURES 1
#endif
// detect Nvidia platforms
#ifdef cl_nv_pragma_unroll
#define NVIDIA 1
#endif
// use binary patching for AMD GCN 1.2 or newer
#ifndef AMD_USE_DS_ADD_F32
#define AMD_USE_DS_ADD_F32 0
#endif
typedef uint data_size_t;
typedef float score_t;
#define ATOMIC_FADD_SUB1 { \
expected.f_val = current.f_val; \
next.f_val = expected.f_val + val; \
current.u_val = atom_cmpxchg((volatile __local acc_int_type *)addr, expected.u_val, next.u_val); \
if (current.u_val == expected.u_val) \
goto end; \
}
#define ATOMIC_FADD_SUB2 ATOMIC_FADD_SUB1 \
ATOMIC_FADD_SUB1
#define ATOMIC_FADD_SUB4 ATOMIC_FADD_SUB2 \
ATOMIC_FADD_SUB2
#define ATOMIC_FADD_SUB8 ATOMIC_FADD_SUB4 \
ATOMIC_FADD_SUB4
#define ATOMIC_FADD_SUB16 ATOMIC_FADD_SUB8 \
ATOMIC_FADD_SUB8
#define ATOMIC_FADD_SUB32 ATOMIC_FADD_SUB16\
ATOMIC_FADD_SUB16
#define ATOMIC_FADD_SUB64 ATOMIC_FADD_SUB32\
ATOMIC_FADD_SUB32
// atomic add for float number in local memory
inline void atomic_local_add_f(__local acc_type *addr, const float val)
{
union{
acc_int_type u_val;
acc_type f_val;
} next, expected, current;
#if (NVIDIA == 1 && USE_DP_FLOAT == 0)
float res = 0;
asm volatile ("atom.shared.add.f32 %0, [%1], %2;" : "=f"(res) : "l"(addr), "f"(val));
#elif (AMD_USE_DS_ADD_F32 == 1 && USE_DP_FLAT == 0)
// this instruction (DS_AND_U32) will be patched into a DS_ADD_F32
// we need to hack here because DS_ADD_F32 is not exposed via OpenCL
atom_and((__local acc_int_type *)addr, as_acc_int_type(val));
#else
current.f_val = *addr;
#if UNROLL_ATOMIC == 1
// provide a fast path
// then do the complete loop
// this should work on all devices
ATOMIC_FADD_SUB8
ATOMIC_FADD_SUB4
ATOMIC_FADD_SUB2
#endif
do {
expected.f_val = current.f_val;
next.f_val = expected.f_val + val;
current.u_val = atom_cmpxchg((volatile __local acc_int_type *)addr, expected.u_val, next.u_val);
} while (current.u_val != expected.u_val);
end:
;
#endif
}
// this function will be called by histogram64
// we have one sub-histogram of one feature in registers, and need to read others
void within_kernel_reduction64x4(uchar4 feature_mask,
__global const acc_type* restrict feature4_sub_hist,
const uint skip_id,
acc_type g_val, acc_type h_val, uint cnt_val,
const ushort num_sub_hist,
__global acc_type* restrict output_buf,
__local acc_type * restrict local_hist) {
const ushort ltid = get_local_id(0); // range 0 - 255
const ushort lsize = LOCAL_SIZE_0;
ushort feature_id = ltid & 3; // range 0 - 4
const ushort bin_id = ltid >> 2; // range 0 - 63W
ushort i;
#if POWER_FEATURE_WORKGROUPS != 0
// if there is only 1 work group, no need to do the reduction
// add all sub-histograms for 4 features
__global const acc_type* restrict p = feature4_sub_hist + ltid;
for (i = 0; i < skip_id; ++i) {
g_val += *p; p += NUM_BINS * 4; // 256 threads working on 4 features' 64 bins
h_val += *p; p += NUM_BINS * 4;
cnt_val += as_acc_int_type(*p); p += NUM_BINS * 4;
}
// skip the counters we already have
p += 3 * 4 * NUM_BINS;
for (i = i + 1; i < num_sub_hist; ++i) {
g_val += *p; p += NUM_BINS * 4;
h_val += *p; p += NUM_BINS * 4;
cnt_val += as_acc_int_type(*p); p += NUM_BINS * 4;
}
#endif
// printf("thread %d: g_val=%f, h_val=%f cnt=%d", ltid, g_val, h_val, cnt_val);
// now overwrite the local_hist for final reduction and output
// reverse the f3...f0 order to match the real order
feature_id = 3 - feature_id;
local_hist[feature_id * 3 * NUM_BINS + bin_id * 3 + 0] = g_val;
local_hist[feature_id * 3 * NUM_BINS + bin_id * 3 + 1] = h_val;
local_hist[feature_id * 3 * NUM_BINS + bin_id * 3 + 2] = as_acc_type((acc_int_type)cnt_val);
barrier(CLK_LOCAL_MEM_FENCE);
i = ltid;
if (feature_mask.s0 && i < 1 * 3 * NUM_BINS) {
output_buf[i] = local_hist[i];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s1 && i < 2 * 3 * NUM_BINS) {
output_buf[i] = local_hist[i];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s2 && i < 3 * 3 * NUM_BINS) {
output_buf[i] = local_hist[i];
}
i += 1 * 3 * NUM_BINS;
if (feature_mask.s3 && i < 4 * 3 * NUM_BINS) {
output_buf[i] = local_hist[i];
}
}
#define printf
__attribute__((reqd_work_group_size(LOCAL_SIZE_0, 1, 1)))
#if USE_CONSTANT_BUF == 1
__kernel void histogram64(__global const uchar4* restrict feature_data_base,
__constant const uchar4* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))),
const data_size_t num_data,
__constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))),
#if CONST_HESSIAN == 0
__constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))),
#else
const score_t const_hessian,
#endif
__global char* restrict output_buf,
__global volatile int * sync_counters,
__global acc_type* restrict hist_buf_base) {
#else
__kernel void histogram64(__global const uchar4* feature_data_base,
__constant const uchar4* restrict feature_masks __attribute__((max_constant_size(65536))),
const data_size_t feature_size,
__global const data_size_t* data_indices,
const data_size_t num_data,
__global const score_t* ordered_gradients,
#if CONST_HESSIAN == 0
__global const score_t* ordered_hessians,
#else
const score_t const_hessian,
#endif
__global char* restrict output_buf,
__global volatile int * sync_counters,
__global acc_type* restrict hist_buf_base) {
#endif
// allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms
// otherwise a "Misaligned Address" exception may occur
__local float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)];
const uint gtid = get_global_id(0);
const uint gsize = get_global_size(0);
const ushort ltid = get_local_id(0);
const ushort lsize = LOCAL_SIZE_0; // get_local_size(0);
const ushort group_id = get_group_id(0);
// local memory per workgroup is 12 KB
// clear local memory
__local uint * ptr = (__local uint *) shared_array;
for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(uint); i += lsize) {
ptr[i] = 0;
}
barrier(CLK_LOCAL_MEM_FENCE);
// gradient/hessian histograms
// assume this starts at 32 * 4 = 128-byte boundary
// each bank: 2 * 4 * 64 * size_of(float) = 2 KB
// there are 4 banks (sub-histograms) used by 256 threads total 8 KB
/* memory layout of gh_hist:
-----------------------------------------------------------------------------------------------
bk0_g_f0_bin0 bk0_g_f1_bin0 bk0_g_f2_bin0 bk0_g_f3_bin0 bk0_h_f0_bin0 bk0_h_f1_bin0 bk0_h_f2_bin0 bk0_h_f3_bin0
bk1_g_f0_bin0 bk1_g_f1_bin0 bk1_g_f2_bin0 bk1_g_f3_bin0 bk1_h_f0_bin0 bk1_h_f1_bin0 bk1_h_f2_bin0 bk1_h_f3_bin0
bk2_g_f0_bin0 bk2_g_f1_bin0 bk2_g_f2_bin0 bk2_g_f3_bin0 bk2_h_f0_bin0 bk2_h_f1_bin0 bk2_h_f2_bin0 bk2_h_f3_bin0
bk3_g_f0_bin0 bk3_g_f1_bin0 bk3_g_f2_bin0 bk3_g_f3_bin0 bk3_h_f0_bin0 bk3_h_f1_bin0 bk3_h_f2_bin0 bk3_h_f3_bin0
bk0_g_f0_bin1 bk0_g_f1_bin1 bk0_g_f2_bin1 bk0_g_f3_bin1 bk0_h_f0_bin1 bk0_h_f1_bin1 bk0_h_f2_bin1 bk0_h_f3_bin1
bk1_g_f0_bin1 bk1_g_f1_bin1 bk1_g_f2_bin1 bk1_g_f3_bin1 bk1_h_f0_bin1 bk1_h_f1_bin1 bk1_h_f2_bin1 bk1_h_f3_bin1
bk2_g_f0_bin1 bk2_g_f1_bin1 bk2_g_f2_bin1 bk2_g_f3_bin1 bk2_h_f0_bin1 bk2_h_f1_bin1 bk2_h_f2_bin1 bk2_h_f3_bin1
bk3_g_f0_bin1 bk3_g_f1_bin1 bk3_g_f2_bin1 bk3_g_f3_bin1 bk3_h_f0_bin1 bk3_h_f1_bin1 bk3_h_f2_bin1 bk3_h_f3_bin1
...
bk0_g_f0_bin64 bk0_g_f1_bin64 bk0_g_f2_bin64 bk0_g_f3_bin64 bk0_h_f0_bin64 bk0_h_f1_bin64 bk0_h_f2_bin64 bk0_h_f3_bin64
bk1_g_f0_bin64 bk1_g_f1_bin64 bk1_g_f2_bin64 bk1_g_f3_bin64 bk1_h_f0_bin64 bk1_h_f1_bin64 bk1_h_f2_bin64 bk1_h_f3_bin64
bk2_g_f0_bin64 bk2_g_f1_bin64 bk2_g_f2_bin64 bk2_g_f3_bin64 bk2_h_f0_bin64 bk2_h_f1_bin64 bk2_h_f2_bin64 bk2_h_f3_bin64
bk3_g_f0_bin64 bk3_g_f1_bin64 bk3_g_f2_bin64 bk3_g_f3_bin64 bk3_h_f0_bin64 bk3_h_f1_bin64 bk3_h_f2_bin64 bk3_h_f3_bin64
-----------------------------------------------------------------------------------------------
*/
// with this organization, the LDS/shared memory bank is independent of the bin value
// all threads within a quarter-wavefront (half-warp) will not have any bank conflict
__local acc_type * gh_hist = (__local acc_type *)shared_array;
// counter histogram
// each bank: 4 * 64 * size_of(uint) = 1 KB
// there are 4 banks used by 256 threads total 4 KB
/* memory layout in cnt_hist:
-----------------------------------------------
bk0_c_f0_bin0 bk0_c_f1_bin0 bk0_c_f2_bin0 bk0_c_f3_bin0
bk1_c_f0_bin0 bk1_c_f1_bin0 bk1_c_f2_bin0 bk1_c_f3_bin0
bk2_c_f0_bin0 bk2_c_f1_bin0 bk2_c_f2_bin0 bk2_c_f3_bin0
bk3_c_f0_bin0 bk3_c_f1_bin0 bk3_c_f2_bin0 bk3_c_f3_bin0
bk0_c_f0_bin1 bk0_c_f1_bin1 bk0_c_f2_bin1 bk0_c_f3_bin1
bk1_c_f0_bin1 bk1_c_f1_bin1 bk1_c_f2_bin1 bk1_c_f3_bin1
bk2_c_f0_bin1 bk2_c_f1_bin1 bk2_c_f2_bin1 bk2_c_f3_bin1
bk3_c_f0_bin1 bk3_c_f1_bin1 bk3_c_f2_bin1 bk3_c_f3_bin1
...
bk0_c_f0_bin64 bk0_c_f1_bin64 bk0_c_f2_bin64 bk0_c_f3_bin64
bk1_c_f0_bin64 bk1_c_f1_bin64 bk1_c_f2_bin64 bk1_c_f3_bin64
bk2_c_f0_bin64 bk2_c_f1_bin64 bk2_c_f2_bin64 bk2_c_f3_bin64
bk3_c_f0_bin64 bk3_c_f1_bin64 bk3_c_f2_bin64 bk3_c_f3_bin64
-----------------------------------------------
*/
__local uint * cnt_hist = (__local uint *)(gh_hist + 2 * 4 * NUM_BINS * NUM_BANKS);
// thread 0, 1, 2, 3 compute histograms for gradients first
// thread 4, 5, 6, 7 compute histograms for hessians first
// etc.
uchar is_hessian_first = (ltid >> 2) & 1;
// thread 0-7 write result to bank0, 8-15 to bank1, 16-23 to bank2, 24-31 to bank3
ushort bank = (ltid >> 3) & BANK_MASK;
ushort group_feature = group_id >> POWER_FEATURE_WORKGROUPS;
// each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant)
// feature_size is the number of examples per feature
__global const uchar4* feature_data = feature_data_base + group_feature * feature_size;
// size of threads that process this feature4
const uint subglobal_size = lsize * (1 << POWER_FEATURE_WORKGROUPS);
// equavalent thread ID in this subgroup for this feature4
const uint subglobal_tid = gtid - group_feature * subglobal_size;
// extract feature mask, when a byte is set to 0, that feature is disabled
#if ENABLE_ALL_FEATURES == 1
// hopefully the compiler will propogate the constants and eliminate all branches
uchar4 feature_mask = (uchar4)(0xff, 0xff, 0xff, 0xff);
#else
uchar4 feature_mask = feature_masks[group_feature];
#endif
// exit if all features are masked
if (!as_uint(feature_mask)) {
return;
}
// STAGE 1: read feature data, and gradient and hessian
// first half of the threads read feature data from global memory
// 4 features stored in a tuple MSB...(0, 1, 2, 3)...LSB
// We will prefetch data into the "next" variable at the beginning of each iteration
uchar4 feature4;
uchar4 feature4_next;
uchar4 feature4_prev;
// offset used to rotate feature4 vector
ushort offset = (ltid & 0x3);
// store gradient and hessian
float stat1, stat2;
float stat1_next, stat2_next;
ushort bin, addr, addr2;
data_size_t ind;
data_size_t ind_next;
stat1 = ordered_gradients[subglobal_tid];
#if CONST_HESSIAN == 0
stat2 = ordered_hessians[subglobal_tid];
#endif
#ifdef IGNORE_INDICES
ind = subglobal_tid;
#else
ind = data_indices[subglobal_tid];
#endif
feature4 = feature_data[ind];
feature4 = as_uchar4(as_uint(feature4) & 0x3f3f3f3f);
feature4_prev = feature4;
feature4_prev = as_uchar4(rotate(as_uint(feature4_prev), (uint)offset*8));
#if ENABLE_ALL_FEATURES == 0
// rotate feature_mask to match the feature order of each thread
feature_mask = as_uchar4(rotate(as_uint(feature_mask), (uint)offset*8));
#endif
acc_type s3_stat1 = 0.0f, s3_stat2 = 0.0f;
acc_type s2_stat1 = 0.0f, s2_stat2 = 0.0f;
acc_type s1_stat1 = 0.0f, s1_stat2 = 0.0f;
acc_type s0_stat1 = 0.0f, s0_stat2 = 0.0f;
// there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4
for (uint i = subglobal_tid; i < num_data; i += subglobal_size) {
// prefetch the next iteration variables
// we don't need bondary check because we have made the buffer larger
stat1_next = ordered_gradients[i + subglobal_size];
#if CONST_HESSIAN == 0
stat2_next = ordered_hessians[i + subglobal_size];
#endif
#ifdef IGNORE_INDICES
// we need to check to bounds here
ind_next = i + subglobal_size < num_data ? i + subglobal_size : i;
// start load next feature as early as possible
feature4_next = feature_data[ind_next];
#else
ind_next = data_indices[i + subglobal_size];
#endif
#if CONST_HESSIAN == 0
// swap gradient and hessian for threads 4, 5, 6, 7
float tmp = stat1;
stat1 = is_hessian_first ? stat2 : stat1;
stat2 = is_hessian_first ? tmp : stat2;
// stat1 = select(stat1, stat2, is_hessian_first);
// stat2 = select(stat2, tmp, is_hessian_first);
#endif
// STAGE 2: accumulate gradient and hessian
offset = (ltid & 0x3);
feature4 = as_uchar4(rotate(as_uint(feature4), (uint)offset*8));
bin = feature4.s3;
if ((bin != feature4_prev.s3) && feature_mask.s3) {
// printf("%3d (%4d): writing s3 %d %d offset %d", ltid, i, bin, feature4_prev.s3, offset);
bin = feature4_prev.s3;
feature4_prev.s3 = feature4.s3;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
// thread 0, 1, 2, 3 now process feature 0, 1, 2, 3's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 0, 1, 2, 3's hessians for example 4, 5, 6, 7
atomic_local_add_f(gh_hist + addr, s3_stat1);
// thread 0, 1, 2, 3 now process feature 0, 1, 2, 3's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 0, 1, 2, 3's gradients for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s3_stat2);
#endif
s3_stat1 = stat1;
s3_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s3 %d", ltid, i, bin);
s3_stat1 += stat1;
s3_stat2 += stat2;
}
bin = feature4.s2;
offset = (offset + 1) & 0x3;
if ((bin != feature4_prev.s2) && feature_mask.s2) {
// printf("%3d (%4d): writing s2 %d %d feature %d", ltid, i, bin, feature4_prev.s2, offset);
bin = feature4_prev.s2;
feature4_prev.s2 = feature4.s2;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
// thread 0, 1, 2, 3 now process feature 1, 2, 3, 0's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 1, 2, 3, 0's hessians for example 4, 5, 6, 7
atomic_local_add_f(gh_hist + addr, s2_stat1);
// thread 0, 1, 2, 3 now process feature 1, 2, 3, 0's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 1, 2, 3, 0's gradients for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s2_stat2);
#endif
s2_stat1 = stat1;
s2_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s2 %d", ltid, i, bin);
s2_stat1 += stat1;
s2_stat2 += stat2;
}
// prefetch the next iteration variables
// we don't need bondary check because if it is out of boundary, ind_next = 0
#ifndef IGNORE_INDICES
feature4_next = feature_data[ind_next];
#endif
bin = feature4.s1 & 0x3f;
offset = (offset + 1) & 0x3;
if ((bin != feature4_prev.s1) && feature_mask.s1) {
// printf("%3d (%4d): writing s1 %d %d feature %d", ltid, i, bin, feature4_prev.s1, offset);
bin = feature4_prev.s1;
feature4_prev.s1 = feature4.s1;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
// thread 0, 1, 2, 3 now process feature 2, 3, 0, 1's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 2, 3, 0, 1's hessians for example 4, 5, 6, 7
atomic_local_add_f(gh_hist + addr, s1_stat1);
// thread 0, 1, 2, 3 now process feature 2, 3, 0, 1's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 2, 3, 0, 1's gradients for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s1_stat2);
#endif
s1_stat1 = stat1;
s1_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s1 %d", ltid, i, bin);
s1_stat1 += stat1;
s1_stat2 += stat2;
}
bin = feature4.s0;
offset = (offset + 1) & 0x3;
if ((bin != feature4_prev.s0) && feature_mask.s0) {
// printf("%3d (%4d): writing s0 %d %d feature %d", ltid, i, bin, feature4_prev.s0, offset);
bin = feature4_prev.s0;
feature4_prev.s0 = feature4.s0;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
// thread 0, 1, 2, 3 now process feature 3, 0, 1, 2's gradients for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 3, 0, 1, 2's hessians for example 4, 5, 6, 7
atomic_local_add_f(gh_hist + addr, s0_stat1);
// thread 0, 1, 2, 3 now process feature 3, 0, 1, 2's hessians for example 0, 1, 2, 3
// thread 4, 5, 6, 7 now process feature 3, 0, 1, 2's gradients for example 4, 5, 6, 7
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s0_stat2);
#endif
s0_stat1 = stat1;
s0_stat2 = stat2;
}
else {
// printf("%3d (%4d): acc s0 %d", ltid, i, bin);
s0_stat1 += stat1;
s0_stat2 += stat2;
}
// STAGE 3: accumulate counter
// there are 4 counters for 4 features
// thread 0, 1, 2, 3 now process feature 0, 1, 2, 3's counts for example 0, 1, 2, 3
offset = (ltid & 0x3);
if (feature_mask.s3) {
bin = feature4.s3;
addr = bin * CNT_BIN_MULT + bank * 4 + offset;
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3 now process feature 1, 2, 3, 0's counts for example 0, 1, 2, 3
offset = (offset + 1) & 0x3;
if (feature_mask.s2) {
bin = feature4.s2;
addr = bin * CNT_BIN_MULT + bank * 4 + offset;
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3 now process feature 2, 3, 0, 1's counts for example 0, 1, 2, 3
offset = (offset + 1) & 0x3;
if (feature_mask.s1) {
bin = feature4.s1;
addr = bin * CNT_BIN_MULT + bank * 4 + offset;
atom_inc(cnt_hist + addr);
}
// thread 0, 1, 2, 3 now process feature 3, 0, 1, 2's counts for example 0, 1, 2, 3
offset = (offset + 1) & 0x3;
if (feature_mask.s0) {
bin = feature4.s0;
addr = bin * CNT_BIN_MULT + bank * 4 + offset;
atom_inc(cnt_hist + addr);
}
stat1 = stat1_next;
stat2 = stat2_next;
feature4 = feature4_next;
feature4 = as_uchar4(as_uint(feature4) & 0x3f3f3f3f);
}
bin = feature4_prev.s3;
offset = (ltid & 0x3);
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s3_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s3_stat2);
#endif
bin = feature4_prev.s2;
offset = (offset + 1) & 0x3;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s2_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s2_stat2);
#endif
bin = feature4_prev.s1;
offset = (offset + 1) & 0x3;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s1_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s1_stat2);
#endif
bin = feature4_prev.s0;
offset = (offset + 1) & 0x3;
addr = bin * HG_BIN_MULT + bank * 8 + is_hessian_first * 4 + offset;
addr2 = addr + 4 - 8 * is_hessian_first;
atomic_local_add_f(gh_hist + addr, s0_stat1);
#if CONST_HESSIAN == 0
atomic_local_add_f(gh_hist + addr2, s0_stat2);
#endif
barrier(CLK_LOCAL_MEM_FENCE);
#if ENABLE_ALL_FEATURES == 0
// restore feature_mask
feature_mask = feature_masks[group_feature];
#endif
// now reduce the 4 banks of subhistograms into 1
/* memory layout of gh_hist:
-----------------------------------------------------------------------------------------------
bk0_g_f0_bin0 bk0_g_f1_bin0 bk0_g_f2_bin0 bk0_g_f3_bin0 bk0_h_f0_bin0 bk0_h_f1_bin0 bk0_h_f2_bin0 bk0_h_f3_bin0
bk1_g_f0_bin0 bk1_g_f1_bin0 bk1_g_f2_bin0 bk1_g_f3_bin0 bk1_h_f0_bin0 bk1_h_f1_bin0 bk1_h_f2_bin0 bk1_h_f3_bin0
bk2_g_f0_bin0 bk2_g_f1_bin0 bk2_g_f2_bin0 bk2_g_f3_bin0 bk2_h_f0_bin0 bk2_h_f1_bin0 bk2_h_f2_bin0 bk2_h_f3_bin0
bk3_g_f0_bin0 bk3_g_f1_bin0 bk3_g_f2_bin0 bk3_g_f3_bin0 bk3_h_f0_bin0 bk3_h_f1_bin0 bk3_h_f2_bin0 bk3_h_f3_bin0
bk0_g_f0_bin1 bk0_g_f1_bin1 bk0_g_f2_bin1 bk0_g_f3_bin1 bk0_h_f0_bin1 bk0_h_f1_bin1 bk0_h_f2_bin1 bk0_h_f3_bin1
bk1_g_f0_bin1 bk1_g_f1_bin1 bk1_g_f2_bin1 bk1_g_f3_bin1 bk1_h_f0_bin1 bk1_h_f1_bin1 bk1_h_f2_bin1 bk1_h_f3_bin1
bk2_g_f0_bin1 bk2_g_f1_bin1 bk2_g_f2_bin1 bk2_g_f3_bin1 bk2_h_f0_bin1 bk2_h_f1_bin1 bk2_h_f2_bin1 bk2_h_f3_bin1
bk3_g_f0_bin1 bk3_g_f1_bin1 bk3_g_f2_bin1 bk3_g_f3_bin1 bk3_h_f0_bin1 bk3_h_f1_bin1 bk3_h_f2_bin1 bk3_h_f3_bin1
...
bk0_g_f0_bin64 bk0_g_f1_bin64 bk0_g_f2_bin64 bk0_g_f3_bin64 bk0_h_f0_bin64 bk0_h_f1_bin64 bk0_h_f2_bin64 bk0_h_f3_bin64
bk1_g_f0_bin64 bk1_g_f1_bin64 bk1_g_f2_bin64 bk1_g_f3_bin64 bk1_h_f0_bin64 bk1_h_f1_bin64 bk1_h_f2_bin64 bk1_h_f3_bin64
bk2_g_f0_bin64 bk2_g_f1_bin64 bk2_g_f2_bin64 bk2_g_f3_bin64 bk2_h_f0_bin64 bk2_h_f1_bin64 bk2_h_f2_bin64 bk2_h_f3_bin64
bk3_g_f0_bin64 bk3_g_f1_bin64 bk3_g_f2_bin64 bk3_g_f3_bin64 bk3_h_f0_bin64 bk3_h_f1_bin64 bk3_h_f2_bin64 bk3_h_f3_bin64
-----------------------------------------------------------------------------------------------
*/
/* memory layout in cnt_hist:
-----------------------------------------------
bk0_c_f0_bin0 bk0_c_f1_bin0 bk0_c_f2_bin0 bk0_c_f3_bin0
bk1_c_f0_bin0 bk1_c_f1_bin0 bk1_c_f2_bin0 bk1_c_f3_bin0
bk2_c_f0_bin0 bk2_c_f1_bin0 bk2_c_f2_bin0 bk2_c_f3_bin0
bk3_c_f0_bin0 bk3_c_f1_bin0 bk3_c_f2_bin0 bk3_c_f3_bin0
bk0_c_f0_bin1 bk0_c_f1_bin1 bk0_c_f2_bin1 bk0_c_f3_bin1
bk1_c_f0_bin1 bk1_c_f1_bin1 bk1_c_f2_bin1 bk1_c_f3_bin1
bk2_c_f0_bin1 bk2_c_f1_bin1 bk2_c_f2_bin1 bk2_c_f3_bin1
bk3_c_f0_bin1 bk3_c_f1_bin1 bk3_c_f2_bin1 bk3_c_f3_bin1
...
bk0_c_f0_bin64 bk0_c_f1_bin64 bk0_c_f2_bin64 bk0_c_f3_bin64
bk1_c_f0_bin64 bk1_c_f1_bin64 bk1_c_f2_bin64 bk1_c_f3_bin64
bk2_c_f0_bin64 bk2_c_f1_bin64 bk2_c_f2_bin64 bk2_c_f3_bin64
bk3_c_f0_bin64 bk3_c_f1_bin64 bk3_c_f2_bin64 bk3_c_f3_bin64
-----------------------------------------------
*/
acc_type g_val = 0.0f;
acc_type h_val = 0.0f;
uint cnt_val = 0;
// 256 threads, working on 4 features and 64 bins,
// so each thread has an independent feature/bin to work on.
const ushort feature_id = ltid & 3; // range 0 - 4
const ushort bin_id = ltid >> 2; // range 0 - 63
offset = (ltid >> 2) & BANK_MASK; // helps avoid LDS bank conflicts
for (int i = 0; i < NUM_BANKS; ++i) {
ushort bank_id = (i + offset) & BANK_MASK;
g_val += gh_hist[bin_id * HG_BIN_MULT + bank_id * 8 + feature_id];
h_val += gh_hist[bin_id * HG_BIN_MULT + bank_id * 8 + feature_id + 4];
cnt_val += cnt_hist[bin_id * CNT_BIN_MULT + bank_id * 4 + feature_id];
}
// now thread 0 - 3 holds feature 0, 1, 2, 3's gradient, hessian and count bin 0
// now thread 4 - 7 holds feature 0, 1, 2, 3's gradient, hessian and count bin 1
// etc,
#if CONST_HESSIAN == 1
g_val += h_val;
h_val = cnt_val * const_hessian;
#endif
// write to output
// write gradients and hessians histogram for all 4 features
// output data in linear order for further reduction
// output size = 4 (features) * 3 (counters) * 64 (bins) * sizeof(float)
/* memory layout of output:
g_f0_bin0 g_f1_bin0 g_f2_bin0 g_f3_bin0
g_f0_bin1 g_f1_bin1 g_f2_bin1 g_f3_bin1
...
g_f0_bin63 g_f1_bin63 g_f2_bin63 g_f3_bin63
h_f0_bin0 h_f1_bin0 h_f2_bin0 h_f3_bin0
h_f0_bin1 h_f1_bin1 h_f2_bin1 h_f3_bin1
...
h_f0_bin63 h_f1_bin63 h_f2_bin63 h_f3_bin63
c_f0_bin0 c_f1_bin0 c_f2_bin0 c_f3_bin0
c_f0_bin1 c_f1_bin1 c_f2_bin1 c_f3_bin1
...
c_f0_bin63 c_f1_bin63 c_f2_bin63 c_f3_bin63
*/
// if there is only one workgroup processing this feature4, don't even need to write
uint feature4_id = (group_id >> POWER_FEATURE_WORKGROUPS);
#if POWER_FEATURE_WORKGROUPS != 0
__global acc_type * restrict output = (__global acc_type * restrict)output_buf + group_id * 4 * 3 * NUM_BINS;
// if g_val and h_val are double, they are converted to float here
// write gradients for 4 features
output[0 * 4 * NUM_BINS + ltid] = g_val;
// write hessians for 4 features
output[1 * 4 * NUM_BINS + ltid] = h_val;
// write counts for 4 features
output[2 * 4 * NUM_BINS + ltid] = as_acc_type((acc_int_type)cnt_val);
barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
mem_fence(CLK_GLOBAL_MEM_FENCE);
// To avoid the cost of an extra reducting kernel, we have to deal with some
// gray area in OpenCL. We want the last work group that process this feature to
// make the final reduction, and other threads will just quit.
// This requires that the results written by other workgroups available to the
// last workgroup (memory consistency)
#if NVIDIA == 1
// this is equavalent to CUDA __threadfence();
// ensure the writes above goes to main memory and other workgroups can see it
asm volatile("{\n\tmembar.gl;\n\t}\n\t" :::"memory");
#else
// FIXME: how to do the above on AMD GPUs??
// GCN ISA says that the all writes will bypass L1 cache (write through),
// however when the last thread is reading sub-histogram data we have to
// make sure that no part of data is modified in local L1 cache of other workgroups.
// Otherwise reading can be a problem (atomic operations to get consistency).
// But in our case, the sub-histogram of this workgroup cannot be in the cache
// of another workgroup, so the following trick will work just fine.
#endif
// Now, we want one workgroup to do the final reduction.
// Other workgroups processing the same feature quit.
// The is done by using an global atomic counter.
// On AMD GPUs ideally this should be done in GDS,
// but currently there is no easy way to access it via OpenCL.
__local uint * counter_val = cnt_hist;
if (ltid == 0) {
// all workgroups processing the same feature add this counter
*counter_val = atom_inc(sync_counters + feature4_id);
}
// make sure everyone in this workgroup is here
barrier(CLK_LOCAL_MEM_FENCE);
// everyone in this wrokgroup: if we are the last workgroup, then do reduction!
if (*counter_val == (1 << POWER_FEATURE_WORKGROUPS) - 1) {
if (ltid == 0) {
// printf("workgroup %d start reduction!\n", group_id);
// printf("feature_data[0] = %d %d %d %d", feature_data[0].s0, feature_data[0].s1, feature_data[0].s2, feature_data[0].s3);
// clear the sync counter for using it next time
sync_counters[feature4_id] = 0;
}
#else
// only 1 work group, no need to increase counter
// the reduction will become a simple copy
if (1) {
barrier(CLK_LOCAL_MEM_FENCE);
#endif
// locate our feature4's block in output memory
uint output_offset = (feature4_id << POWER_FEATURE_WORKGROUPS);
__global acc_type const * restrict feature4_subhists =
(__global acc_type *)output_buf + output_offset * 4 * 3 * NUM_BINS;
// skip reading the data already in local memory
uint skip_id = group_id ^ output_offset;
// locate output histogram location for this feature4
__global acc_type* restrict hist_buf = hist_buf_base + feature4_id * 4 * 3 * NUM_BINS;
within_kernel_reduction64x4(feature_mask, feature4_subhists, skip_id, g_val, h_val, cnt_val,
1 << POWER_FEATURE_WORKGROUPS, hist_buf, (__local acc_type *)shared_array);
}
}
// The following line ends the string literal, adds an extra #endif at the end
// the +9 skips extra characters ")", newline, "#endif" and newline at the beginning
// )"" "\n#endif" + 9
#endif
......@@ -5,6 +5,7 @@
#include <LightGBM/network.h>
#include "serial_tree_learner.h"
#include "gpu_tree_learner.h"
#include <cstring>
......@@ -18,11 +19,12 @@ namespace LightGBM {
* Different machine will find best split on different features, then sync global best split
* It is recommonded used when #data is small or #feature is large
*/
class FeatureParallelTreeLearner: public SerialTreeLearner {
template <typename TREELEARNER_T>
class FeatureParallelTreeLearner: public TREELEARNER_T {
public:
explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
~FeatureParallelTreeLearner();
void Init(const Dataset* train_data) override;
void Init(const Dataset* train_data, bool is_constant_hessian) override;
protected:
void BeforeTrain() override;
......@@ -43,11 +45,12 @@ private:
* Workers use local data to construct histograms locally, then sync up global histograms.
* It is recommonded used when #data is large or #feature is small
*/
class DataParallelTreeLearner: public SerialTreeLearner {
template <typename TREELEARNER_T>
class DataParallelTreeLearner: public TREELEARNER_T {
public:
explicit DataParallelTreeLearner(const TreeConfig* tree_config);
~DataParallelTreeLearner();
void Init(const Dataset* train_data) override;
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
......@@ -95,11 +98,12 @@ private:
* Here using voting to reduce features, and only aggregate histograms for selected features.
* When #data is large and #feature is large, you can use this to have better speed-up
*/
class VotingParallelTreeLearner: public SerialTreeLearner {
template <typename TREELEARNER_T>
class VotingParallelTreeLearner: public TREELEARNER_T {
public:
explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
~VotingParallelTreeLearner() { }
void Init(const Dataset* train_data) override;
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetConfig(const TreeConfig* tree_config) override;
protected:
void BeforeTrain() override;
......
......@@ -37,10 +37,11 @@ SerialTreeLearner::~SerialTreeLearner() {
#endif
}
void SerialTreeLearner::Init(const Dataset* train_data) {
void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
train_data_ = train_data;
num_data_ = train_data_->num_data();
num_features_ = train_data_->num_features();
is_constant_hessian_ = is_constant_hessian;
int max_cache_size = 0;
// Get the max size of pool
if (tree_config_->histogram_pool_size <= 0) {
......
......@@ -18,6 +18,11 @@
#include <random>
#include <cmath>
#include <memory>
#ifdef USE_GPU
// Use 4KBytes aligned allocator for ordered gradients and ordered hessians when GPU is enabled.
// This is necessary to pin the two arrays in memory and make transferring faster.
#include <boost/align/aligned_allocator.hpp>
#endif
namespace LightGBM {
......@@ -30,7 +35,7 @@ public:
~SerialTreeLearner();
void Init(const Dataset* train_data) override;
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override;
......@@ -69,7 +74,7 @@ protected:
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
/*!
* \brief Find best thresholds for all features, using multi-threading.
......@@ -130,10 +135,17 @@ protected:
/*! \brief stores best thresholds for all feature for larger leaf */
std::unique_ptr<LeafSplits> larger_leaf_splits_;
#ifdef USE_GPU
/*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_hessians_;
#else
/*! \brief gradients of current iteration, ordered for cache optimized */
std::vector<score_t> ordered_gradients_;
/*! \brief hessians of current iteration, ordered for cache optimized */
std::vector<score_t> ordered_hessians_;
#endif
/*! \brief Store ordered bin */
std::vector<std::unique_ptr<OrderedBin>> ordered_bins_;
......
#include <LightGBM/tree_learner.h>
#include "serial_tree_learner.h"
#include "gpu_tree_learner.h"
#include "parallel_tree_learner.h"
namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& type, const TreeConfig* tree_config) {
if (type == std::string("serial")) {
return new SerialTreeLearner(tree_config);
} else if (type == std::string("feature")) {
return new FeatureParallelTreeLearner(tree_config);
} else if (type == std::string("data")) {
return new DataParallelTreeLearner(tree_config);
} else if (type == std::string("voting")) {
return new VotingParallelTreeLearner(tree_config);
TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, const std::string& device_type, const TreeConfig* tree_config) {
if (device_type == std::string("cpu")) {
if (learner_type == std::string("serial")) {
return new SerialTreeLearner(tree_config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<SerialTreeLearner>(tree_config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<SerialTreeLearner>(tree_config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<SerialTreeLearner>(tree_config);
}
}
else if (device_type == std::string("gpu")) {
if (learner_type == std::string("serial")) {
return new GPUTreeLearner(tree_config);
} else if (learner_type == std::string("feature")) {
return new FeatureParallelTreeLearner<GPUTreeLearner>(tree_config);
} else if (learner_type == std::string("data")) {
return new DataParallelTreeLearner<GPUTreeLearner>(tree_config);
} else if (learner_type == std::string("voting")) {
return new VotingParallelTreeLearner<GPUTreeLearner>(tree_config);
}
}
return nullptr;
}
......
......@@ -9,25 +9,27 @@
namespace LightGBM {
VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig* tree_config)
:SerialTreeLearner(tree_config) {
top_k_ = tree_config_->top_k;
template <typename TREELEARNER_T>
VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const TreeConfig* tree_config)
:TREELEARNER_T(tree_config) {
top_k_ = this->tree_config_->top_k;
}
void VotingParallelTreeLearner::Init(const Dataset* train_data) {
SerialTreeLearner::Init(train_data);
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, bool is_constant_hessian) {
TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank();
num_machines_ = Network::num_machines();
// limit top k
if (top_k_ > num_features_) {
top_k_ = num_features_;
if (top_k_ > this->num_features_) {
top_k_ = this->num_features_;
}
// get max bin
int max_bin = 0;
for (int i = 0; i < num_features_; ++i) {
if (max_bin < train_data_->FeatureNumBin(i)) {
max_bin = train_data_->FeatureNumBin(i);
for (int i = 0; i < this->num_features_; ++i) {
if (max_bin < this->train_data_->FeatureNumBin(i)) {
max_bin = this->train_data_->FeatureNumBin(i);
}
}
// calculate buffer size
......@@ -36,29 +38,29 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {
input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
smaller_is_feature_aggregated_.resize(num_features_);
larger_is_feature_aggregated_.resize(num_features_);
smaller_is_feature_aggregated_.resize(this->num_features_);
larger_is_feature_aggregated_.resize(this->num_features_);
block_start_.resize(num_machines_);
block_len_.resize(num_machines_);
smaller_buffer_read_start_pos_.resize(num_features_);
larger_buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
smaller_buffer_read_start_pos_.resize(this->num_features_);
larger_buffer_read_start_pos_.resize(this->num_features_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves);
smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_data()));
smaller_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data()));
larger_leaf_splits_global_.reset(new LeafSplits(this->train_data_->num_data()));
local_tree_config_ = *tree_config_;
local_tree_config_ = *this->tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
histogram_pool_.ResetConfig(&local_tree_config_);
this->histogram_pool_.ResetConfig(&local_tree_config_);
// initialize histograms for global
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
auto num_total_bin = train_data_->NumTotalBin();
smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
larger_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
auto num_total_bin = this->train_data_->NumTotalBin();
smaller_leaf_histogram_data_.resize(num_total_bin);
larger_leaf_histogram_data_.resize(num_total_bin);
feature_metas_.resize(train_data->num_features());
......@@ -70,7 +72,7 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {
} else {
feature_metas_[i].bias = 0;
}
feature_metas_[i].tree_config = tree_config_;
feature_metas_[i].tree_config = this->tree_config_;
}
uint64_t offset = 0;
for (int j = 0; j < train_data->num_features(); ++j) {
......@@ -85,25 +87,27 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {
}
}
void VotingParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
SerialTreeLearner::ResetConfig(tree_config);
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const TreeConfig* tree_config) {
TREELEARNER_T::ResetConfig(tree_config);
local_tree_config_ = *tree_config_;
local_tree_config_ = *this->tree_config_;
local_tree_config_.min_data_in_leaf /= num_machines_;
local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
histogram_pool_.ResetConfig(&local_tree_config_);
global_data_count_in_leaf_.resize(tree_config_->num_leaves);
this->histogram_pool_.ResetConfig(&local_tree_config_);
global_data_count_in_leaf_.resize(this->tree_config_->num_leaves);
for (size_t i = 0; i < feature_metas_.size(); ++i) {
feature_metas_[i].tree_config = tree_config_;
feature_metas_[i].tree_config = this->tree_config_;
}
}
void VotingParallelTreeLearner::BeforeTrain() {
SerialTreeLearner::BeforeTrain();
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
TREELEARNER_T::BeforeTrain();
// sync global data sumup info
std::tuple<data_size_t, double, double> data(smaller_leaf_splits_->num_data_in_leaf(), smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians());
std::tuple<data_size_t, double, double> data(this->smaller_leaf_splits_->num_data_in_leaf(), this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians());
int size = sizeof(std::tuple<data_size_t, double, double>);
std::memcpy(input_buffer_.data(), &data, size);
......@@ -133,20 +137,21 @@ void VotingParallelTreeLearner::BeforeTrain() {
global_data_count_in_leaf_[0] = std::get<0>(data);
}
bool VotingParallelTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
if (SerialTreeLearner::BeforeFindBestSplit(tree, left_leaf, right_leaf)) {
template <typename TREELEARNER_T>
bool VotingParallelTreeLearner<TREELEARNER_T>::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
if (TREELEARNER_T::BeforeFindBestSplit(tree, left_leaf, right_leaf)) {
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
if (right_leaf < 0) {
return true;
} else if (num_data_in_left_child < num_data_in_right_child) {
// get local sumup
smaller_leaf_splits_->Init(left_leaf, data_partition_.get(), gradients_, hessians_);
larger_leaf_splits_->Init(right_leaf, data_partition_.get(), gradients_, hessians_);
this->smaller_leaf_splits_->Init(left_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
this->larger_leaf_splits_->Init(right_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
} else {
// get local sumup
smaller_leaf_splits_->Init(right_leaf, data_partition_.get(), gradients_, hessians_);
larger_leaf_splits_->Init(left_leaf, data_partition_.get(), gradients_, hessians_);
this->smaller_leaf_splits_->Init(right_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
this->larger_leaf_splits_->Init(left_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
}
return true;
} else {
......@@ -154,14 +159,15 @@ bool VotingParallelTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_l
}
}
void VotingParallelTreeLearner::GlobalVoting(int leaf_idx, const std::vector<SplitInfo>& splits, std::vector<int>* out) {
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const std::vector<SplitInfo>& splits, std::vector<int>* out) {
out->clear();
if (leaf_idx < 0) {
return;
}
// get mean number on machines
score_t mean_num_data = GetGlobalDataCountInLeaf(leaf_idx) / static_cast<score_t>(num_machines_);
std::vector<SplitInfo> feature_best_split(num_features_, SplitInfo());
std::vector<SplitInfo> feature_best_split(this->num_features_, SplitInfo());
for (auto & split : splits) {
int fid = split.feature;
if (fid < 0) {
......@@ -185,8 +191,9 @@ void VotingParallelTreeLearner::GlobalVoting(int leaf_idx, const std::vector<Spl
}
}
void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& smaller_top_features, const std::vector<int>& larger_top_features) {
for (int i = 0; i < num_features_; ++i) {
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vector<int>& smaller_top_features, const std::vector<int>& larger_top_features) {
for (int i = 0; i < this->num_features_; ++i) {
smaller_is_feature_aggregated_[i] = false;
larger_is_feature_aggregated_[i] = false;
}
......@@ -203,7 +210,7 @@ void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& small
while (cur_used_features < cur_total_feature) {
// copy smaller leaf histograms first
if (smaller_idx < smaller_top_features.size()) {
int inner_feature_index = train_data_->InnerFeatureIndex(smaller_top_features[smaller_idx]);
int inner_feature_index = this->train_data_->InnerFeatureIndex(smaller_top_features[smaller_idx]);
++cur_used_features;
// mark local aggregated feature
if (i == rank_) {
......@@ -211,9 +218,9 @@ void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& small
smaller_buffer_read_start_pos_[inner_feature_index] = static_cast<int>(cur_size);
}
// copy
std::memcpy(input_buffer_.data() + reduce_scatter_size_, smaller_leaf_histogram_array_[inner_feature_index].RawData(), smaller_leaf_histogram_array_[inner_feature_index].SizeOfHistgram());
cur_size += smaller_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
reduce_scatter_size_ += smaller_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
std::memcpy(input_buffer_.data() + reduce_scatter_size_, this->smaller_leaf_histogram_array_[inner_feature_index].RawData(), this->smaller_leaf_histogram_array_[inner_feature_index].SizeOfHistgram());
cur_size += this->smaller_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
reduce_scatter_size_ += this->smaller_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
++smaller_idx;
}
if (cur_used_features >= cur_total_feature) {
......@@ -221,7 +228,7 @@ void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& small
}
// then copy larger leaf histograms
if (larger_idx < larger_top_features.size()) {
int inner_feature_index = train_data_->InnerFeatureIndex(larger_top_features[larger_idx]);
int inner_feature_index = this->train_data_->InnerFeatureIndex(larger_top_features[larger_idx]);
++cur_used_features;
// mark local aggregated feature
if (i == rank_) {
......@@ -229,9 +236,9 @@ void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& small
larger_buffer_read_start_pos_[inner_feature_index] = static_cast<int>(cur_size);
}
// copy
std::memcpy(input_buffer_.data() + reduce_scatter_size_, larger_leaf_histogram_array_[inner_feature_index].RawData(), larger_leaf_histogram_array_[inner_feature_index].SizeOfHistgram());
cur_size += larger_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
reduce_scatter_size_ += larger_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
std::memcpy(input_buffer_.data() + reduce_scatter_size_, this->larger_leaf_histogram_array_[inner_feature_index].RawData(), this->larger_leaf_histogram_array_[inner_feature_index].SizeOfHistgram());
cur_size += this->larger_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
reduce_scatter_size_ += this->larger_leaf_histogram_array_[inner_feature_index].SizeOfHistgram();
++larger_idx;
}
}
......@@ -243,60 +250,61 @@ void VotingParallelTreeLearner::CopyLocalHistogram(const std::vector<int>& small
}
}
void VotingParallelTreeLearner::FindBestThresholds() {
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestThresholds() {
// use local data to find local best splits
std::vector<int8_t> is_feature_used(num_features_, 0);
std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
if (!this->is_feature_used_[feature_index]) continue;
if (this->parent_leaf_histogram_array_ != nullptr
&& !this->parent_leaf_histogram_array_[feature_index].is_splittable()) {
this->smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
if (this->parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
ConstructHistograms(is_feature_used, use_subtract);
this->ConstructHistograms(is_feature_used, use_subtract);
std::vector<SplitInfo> smaller_bestsplit_per_features(num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(num_features_);
std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN();
if (!is_feature_used[feature_index]) { continue; }
const int real_feature_index = train_data_->RealFeatureIndex(feature_index);
train_data_->FixHistogram(feature_index,
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_histogram_array_[feature_index].RawData());
smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
smaller_leaf_splits_->sum_gradients(),
smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(),
const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
this->train_data_->FixHistogram(feature_index,
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_histogram_array_[feature_index].RawData());
this->smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_splits_->num_data_in_leaf(),
&smaller_bestsplit_per_features[feature_index]);
smaller_bestsplit_per_features[feature_index].feature = real_feature_index;
// only has root leaf
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->LeafIndex() < 0) { continue; }
if (use_subtract) {
larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]);
} else {
train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(),
larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_histogram_array_[feature_index].RawData());
this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(),
this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_histogram_array_[feature_index].RawData());
}
// find best threshold for larger child
larger_leaf_histogram_array_[feature_index].FindBestThreshold(
larger_leaf_splits_->sum_gradients(),
larger_leaf_splits_->sum_hessians(),
larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_histogram_array_[feature_index].FindBestThreshold(
this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(),
this->larger_leaf_splits_->num_data_in_leaf(),
&larger_bestsplit_per_features[feature_index]);
larger_bestsplit_per_features[feature_index].feature = real_feature_index;
OMP_LOOP_EX_END();
......@@ -332,8 +340,8 @@ void VotingParallelTreeLearner::FindBestThresholds() {
}
// global voting
std::vector<int> smaller_top_features, larger_top_features;
GlobalVoting(smaller_leaf_splits_->LeafIndex(), smaller_top_k_splits_global, &smaller_top_features);
GlobalVoting(larger_leaf_splits_->LeafIndex(), larger_top_k_splits_global, &larger_top_features);
GlobalVoting(this->smaller_leaf_splits_->LeafIndex(), smaller_top_k_splits_global, &smaller_top_features);
GlobalVoting(this->larger_leaf_splits_->LeafIndex(), larger_top_k_splits_global, &larger_top_features);
// copy local histgrams to buffer
CopyLocalHistogram(smaller_top_features, larger_top_features);
......@@ -341,11 +349,11 @@ void VotingParallelTreeLearner::FindBestThresholds() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(),
output_buffer_.data(), &HistogramBinEntry::SumReducer);
std::vector<SplitInfo> smaller_best(num_threads_);
std::vector<SplitInfo> larger_best(num_threads_);
std::vector<SplitInfo> smaller_best(this->num_threads_);
std::vector<SplitInfo> larger_best(this->num_threads_);
// find best split from local aggregated histograms
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num();
if (smaller_is_feature_aggregated_[feature_index]) {
......@@ -354,7 +362,7 @@ void VotingParallelTreeLearner::FindBestThresholds() {
smaller_leaf_histogram_array_global_[feature_index].FromMemory(
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
train_data_->FixHistogram(feature_index,
this->train_data_->FixHistogram(feature_index,
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_histogram_array_global_[feature_index].RawData());
......@@ -367,7 +375,7 @@ void VotingParallelTreeLearner::FindBestThresholds() {
&smaller_split);
if (smaller_split.gain > smaller_best[tid].gain) {
smaller_best[tid] = smaller_split;
smaller_best[tid].feature = train_data_->RealFeatureIndex(feature_index);
smaller_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index);
}
}
......@@ -376,7 +384,7 @@ void VotingParallelTreeLearner::FindBestThresholds() {
// restore from buffer
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
train_data_->FixHistogram(feature_index,
this->train_data_->FixHistogram(feature_index,
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_histogram_array_global_[feature_index].RawData());
......@@ -389,31 +397,32 @@ void VotingParallelTreeLearner::FindBestThresholds() {
&larger_split);
if (larger_split.gain > larger_best[tid].gain) {
larger_best[tid] = larger_split;
larger_best[tid].feature = train_data_->RealFeatureIndex(feature_index);
larger_best[tid].feature = this->train_data_->RealFeatureIndex(feature_index);
}
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
int leaf = smaller_leaf_splits_->LeafIndex();
best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
int leaf = this->smaller_leaf_splits_->LeafIndex();
this->best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];
if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->LeafIndex() >= 0) {
leaf = larger_leaf_splits_->LeafIndex();
if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->LeafIndex() >= 0) {
leaf = this->larger_leaf_splits_->LeafIndex();
auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
this->best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
}
}
void VotingParallelTreeLearner::FindBestSplitsForLeaves() {
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsForLeaves() {
// find local best
SplitInfo smaller_best, larger_best;
smaller_best = best_split_per_leaf_[smaller_leaf_splits_->LeafIndex()];
smaller_best = this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()];
// find local best split for larger leaf
if (larger_leaf_splits_->LeafIndex() >= 0) {
larger_best = best_split_per_leaf_[larger_leaf_splits_->LeafIndex()];
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
larger_best = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
}
// sync global best info
std::memcpy(input_buffer_.data(), &smaller_best, sizeof(SplitInfo));
......@@ -425,34 +434,38 @@ void VotingParallelTreeLearner::FindBestSplitsForLeaves() {
std::memcpy(&larger_best, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// copy back
best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best;
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best;
if (larger_best.feature >= 0 && larger_leaf_splits_global_->LeafIndex() >= 0) {
best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best;
this->best_split_per_leaf_[larger_leaf_splits_global_->LeafIndex()] = larger_best;
}
}
void VotingParallelTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
SerialTreeLearner::Split(tree, best_Leaf, left_leaf, right_leaf);
const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
TREELEARNER_T::Split(tree, best_Leaf, left_leaf, right_leaf);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count;
// init the global sumup info
if (best_split_info.left_count < best_split_info.right_count) {
smaller_leaf_splits_global_->Init(*left_leaf, data_partition_.get(),
smaller_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian);
larger_leaf_splits_global_->Init(*right_leaf, data_partition_.get(),
larger_leaf_splits_global_->Init(*right_leaf, this->data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian);
} else {
smaller_leaf_splits_global_->Init(*right_leaf, data_partition_.get(),
smaller_leaf_splits_global_->Init(*right_leaf, this->data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian);
larger_leaf_splits_global_->Init(*left_leaf, data_partition_.get(),
larger_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian);
}
}
// instantiate template classes, otherwise linker cannot find the code
template class VotingParallelTreeLearner<GPUTreeLearner>;
template class VotingParallelTreeLearner<SerialTreeLearner>;
} // namespace FTLBoost
......@@ -220,10 +220,10 @@ class TestEngine(unittest.TestCase):
gbm3.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test))
self.assertListEqual(pred0, pred1)
self.assertListEqual(pred0, pred2)
self.assertListEqual(pred0, pred3)
self.assertListEqual(pred0, pred4)
np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2)
np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred0, pred4)
print("----------------------------------------------------------------------")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment