/*!
 * Copyright (c) 2020 IBM Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
#define LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_

#include "LightGBM/meta.h"

namespace LightGBM {

// Use double precision for the accumulator type (1) or single precision (any
// other value). May be overridden by defining USE_DP_FLOAT before inclusion.
#ifndef USE_DP_FLOAT
#define USE_DP_FLOAT 1
#endif

// ignore hessian, and use the local memory for hessian as an additional bank for gradient
#ifndef CONST_HESSIAN
#define CONST_HESSIAN 0
#endif

typedef unsigned char uchar;

// ---------------------------------------------------------------------------
// Bit-reinterpretation helpers.
//
// Each helper copies the object representation of `t` into a value of the
// target type with memcpy and returns it. A static_assert rejects, at compile
// time, any T whose size differs from the target type's size.
// ---------------------------------------------------------------------------

// Reinterprets the bytes of `t` as a double.
template <typename T>
__device__ double as_double(const T t) {
  static_assert(sizeof(T) == sizeof(double), "size mismatch");
  double d;
  memcpy(&d, &t, sizeof(T));
  return d;
}

// Reinterprets the bytes of `t` as an unsigned long long.
template <typename T>
__device__ unsigned long long as_ulong_ulong(const T t) {
  static_assert(sizeof(T) == sizeof(unsigned long long), "size mismatch");
  unsigned long long u;
  memcpy(&u, &t, sizeof(T));
  return u;
}

// Reinterprets the bytes of `t` as a float.
template <typename T>
__device__ float as_float(const T t) {
  static_assert(sizeof(T) == sizeof(float), "size mismatch");
  float f;
  memcpy(&f, &t, sizeof(T));
  return f;
}

// Reinterprets the bytes of `t` as an unsigned int.
template <typename T>
__device__ unsigned int as_uint(const T t) {
  static_assert(sizeof(T) == sizeof(unsigned int), "size_mismatch");
  unsigned int u;
  memcpy(&u, &t, sizeof(T));
  return u;
}

// Reinterprets the bytes of `t` as a uchar4.
template <typename T>
__device__ uchar4 as_uchar4(const T t) {
  static_assert(sizeof(T) == sizeof(uchar4), "size mismatch");
  uchar4 u;
  memcpy(&u, &t, sizeof(T));
  return u;
}

// Accumulator types selected by USE_DP_FLOAT:
//   acc_type        - floating-point accumulator (double or float)
//   acc_int_type    - unsigned integer type of the same size as acc_type
//   as_acc_type     - bit-reinterpretation helper yielding acc_type
//   as_acc_int_type - bit-reinterpretation helper yielding acc_int_type
#if USE_DP_FLOAT == 1
typedef double acc_type;
typedef unsigned long long acc_int_type;
#define as_acc_type as_double
#define as_acc_int_type as_ulong_ulong
#else
typedef float acc_type;
typedef unsigned int acc_int_type;
#define as_acc_type as_float
#define as_acc_int_type as_uint
#endif

// use all features and do not use feature mask
#ifndef ENABLE_ALL_FEATURES
#define ENABLE_ALL_FEATURES 1
#endif

// ---------------------------------------------------------------------------
// Kernel declaration macros.
//
// Each macro expands to the forward declaration (terminated by ';') of one
// __global__ kernel called `name`. The four variants differ only in
//   * the hessian argument: a `const score_t*` array (`ordered_hessians`)
//     versus a single `const score_t` value (`const_hessian`), and
//   * whether the input data pointers carry `__restrict__` (the *_CONST_BUF
//     variants) or not.
// NOTE(review): the meaning of the individual parameters is defined by the
// kernel implementations, which are not part of this header - confirm there.
// ---------------------------------------------------------------------------

// Per-sample hessian array; all input pointers declared __restrict__.
// The qualifiers here are CUDA's `__restrict__`, consistent with
// DECLARE_CONST_HES_CONST_BUF (OpenCL's `__global` / `restrict` spellings are
// not valid in CUDA C++).
#define DECLARE_CONST_BUF(name) \
__global__ void name(const uchar* __restrict__ feature_data_base, \
                     const uchar* __restrict__ feature_masks, \
                     const data_size_t feature_size, \
                     const data_size_t* __restrict__ data_indices, \
                     const data_size_t num_data, \
                     const score_t* __restrict__ ordered_gradients, \
                     const score_t* __restrict__ ordered_hessians, \
                     char* __restrict__ output_buf, \
                     volatile int * sync_counters, \
                     acc_type* __restrict__ hist_buf_base, \
                     const size_t power_feature_workgroups);

// Scalar hessian value; all input pointers declared __restrict__.
#define DECLARE_CONST_HES_CONST_BUF(name) \
__global__ void name(const uchar* __restrict__ feature_data_base, \
                     const uchar* __restrict__ feature_masks, \
                     const data_size_t feature_size, \
                     const data_size_t* __restrict__ data_indices, \
                     const data_size_t num_data, \
                     const score_t* __restrict__ ordered_gradients, \
                     const score_t const_hessian, \
                     char* __restrict__ output_buf, \
                     volatile int * sync_counters, \
                     acc_type* __restrict__ hist_buf_base, \
                     const size_t power_feature_workgroups);

// Scalar hessian value; feature data, data indices and gradients are plain
// (non-__restrict__) pointers.
#define DECLARE_CONST_HES(name) \
__global__ void name(const uchar* feature_data_base, \
                     const uchar* __restrict__ feature_masks, \
                     const data_size_t feature_size, \
                     const data_size_t* data_indices, \
                     const data_size_t num_data, \
                     const score_t* ordered_gradients, \
                     const score_t const_hessian, \
                     char* __restrict__ output_buf, \
                     volatile int * sync_counters, \
                     acc_type* __restrict__ hist_buf_base, \
                     const size_t power_feature_workgroups);

// Per-sample hessian array; feature data, data indices, gradients and
// hessians are plain (non-__restrict__) pointers.
#define DECLARE(name) \
__global__ void name(const uchar* feature_data_base, \
                     const uchar* __restrict__ feature_masks, \
                     const data_size_t feature_size, \
                     const data_size_t* data_indices, \
                     const data_size_t num_data, \
                     const score_t* ordered_gradients, \
                     const score_t* ordered_hessians, \
                     char* __restrict__ output_buf, \
                     volatile int * sync_counters, \
                     acc_type* __restrict__ hist_buf_base, \
                     const size_t power_feature_workgroups);

// Every kernel name is declared twice: once with the scalar-hessian signature
// (DECLARE_CONST_HES) and once with the hessian-array signature (DECLARE).
// The two declarations are distinct overloads because the seventh parameter
// type differs.

// histogram16 family
DECLARE_CONST_HES(histogram16_allfeats);
DECLARE_CONST_HES(histogram16_fulldata);
DECLARE_CONST_HES(histogram16);
DECLARE(histogram16_allfeats);
DECLARE(histogram16_fulldata);
DECLARE(histogram16);

// histogram64 family
DECLARE_CONST_HES(histogram64_allfeats);
DECLARE_CONST_HES(histogram64_fulldata);
DECLARE_CONST_HES(histogram64);
DECLARE(histogram64_allfeats);
DECLARE(histogram64_fulldata);
DECLARE(histogram64);

// histogram256 family
DECLARE_CONST_HES(histogram256_allfeats);
DECLARE_CONST_HES(histogram256_fulldata);
DECLARE_CONST_HES(histogram256);
DECLARE(histogram256_allfeats);
DECLARE(histogram256_fulldata);
DECLARE(histogram256);

}  // namespace LightGBM

#endif  // LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_