cuda_leaf_splits.cu 11 KB
Newer Older
1
2
3
4
5
6
7
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */


8
#ifdef USE_CUDA
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

#include "cuda_leaf_splits.hpp"
#include <LightGBM/cuda/cuda_algorithms.hpp>

namespace LightGBM {

template <bool USE_INDICES>
__global__ void CUDAInitValuesKernel1(const score_t* cuda_gradients, const score_t* cuda_hessians,
  const data_size_t num_data, const data_size_t* cuda_bagging_data_indices,
  double* cuda_sum_of_gradients, double* cuda_sum_of_hessians) {
  __shared__ double shared_mem_buffer[32];
  const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  double gradient = 0.0f;
  double hessian = 0.0f;
  if (data_index < num_data) {
    gradient = USE_INDICES ? cuda_gradients[cuda_bagging_data_indices[data_index]] : cuda_gradients[data_index];
    hessian = USE_INDICES ? cuda_hessians[cuda_bagging_data_indices[data_index]] : cuda_hessians[data_index];
  }
  const double block_sum_gradient = ShuffleReduceSum<double>(gradient, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const double block_sum_hessian = ShuffleReduceSum<double>(hessian, shared_mem_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    cuda_sum_of_gradients[blockIdx.x] += block_sum_gradient;
    cuda_sum_of_hessians[blockIdx.x] += block_sum_hessian;
  }
}

__global__ void CUDAInitValuesKernel2(
  const double lambda_l1,
  const double lambda_l2,
  const int num_blocks_to_reduce,
  double* cuda_sum_of_gradients,
  double* cuda_sum_of_hessians,
  const data_size_t num_data,
  const data_size_t* cuda_data_indices_in_leaf,
  hist_t* cuda_hist_in_leaf,
  CUDALeafSplitsStruct* cuda_struct) {
  __shared__ double shared_mem_buffer[32];
  double thread_sum_of_gradients = 0.0f;
  double thread_sum_of_hessians = 0.0f;
  for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast<int>(blockDim.x)) {
    thread_sum_of_gradients += cuda_sum_of_gradients[block_index];
    thread_sum_of_hessians += cuda_sum_of_hessians[block_index];
  }
  const double sum_of_gradients = ShuffleReduceSum<double>(thread_sum_of_gradients, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const double sum_of_hessians = ShuffleReduceSum<double>(thread_sum_of_hessians, shared_mem_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    cuda_sum_of_hessians[0] = sum_of_hessians;
    cuda_struct->leaf_index = 0;
    cuda_struct->sum_of_gradients = sum_of_gradients;
    cuda_struct->sum_of_hessians = sum_of_hessians;
    cuda_struct->num_data_in_leaf = num_data;
    const bool use_l1 = lambda_l1 > 0.0f;
    if (!use_l1) {
      // no smoothing on root node
      cuda_struct->gain = CUDALeafSplits::GetLeafGain<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    } else {
      // no smoothing on root node
      cuda_struct->gain = CUDALeafSplits::GetLeafGain<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    }
    if (!use_l1) {
      // no smoothing on root node
      cuda_struct->leaf_value =
        CUDALeafSplits::CalculateSplittedLeafOutput<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    } else {
      // no smoothing on root node
      cuda_struct->leaf_value =
        CUDALeafSplits::CalculateSplittedLeafOutput<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    }
    cuda_struct->data_indices_in_leaf = cuda_data_indices_in_leaf;
    cuda_struct->hist_in_leaf = cuda_hist_in_leaf;
  }
}

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
template <bool USE_INDICES>
__global__ void CUDAInitValuesKernel3(const int16_t* cuda_gradients_and_hessians,
  const data_size_t num_data, const data_size_t* cuda_bagging_data_indices,
  double* cuda_sum_of_gradients, double* cuda_sum_of_hessians, int64_t* cuda_sum_of_hessians_hessians,
  const score_t* grad_scale_pointer, const score_t* hess_scale_pointer) {
  const score_t grad_scale = *grad_scale_pointer;
  const score_t hess_scale = *hess_scale_pointer;
  __shared__ int64_t shared_mem_buffer[32];
  const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
  int64_t int_gradient = 0;
  int64_t int_hessian = 0;
  if (data_index < num_data) {
    int_gradient = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index] + 1] :
      cuda_gradients_and_hessians[2 * data_index + 1];
    int_hessian = USE_INDICES ? cuda_gradients_and_hessians[2 * cuda_bagging_data_indices[data_index]] :
      cuda_gradients_and_hessians[2 * data_index];
  }
  const int64_t block_sum_gradient = ShuffleReduceSum<int64_t>(int_gradient, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const int64_t block_sum_hessian = ShuffleReduceSum<int64_t>(int_hessian, shared_mem_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    cuda_sum_of_gradients[blockIdx.x] = block_sum_gradient * grad_scale;
    cuda_sum_of_hessians[blockIdx.x] = block_sum_hessian * hess_scale;
    cuda_sum_of_hessians_hessians[blockIdx.x] = ((block_sum_gradient << 32) | block_sum_hessian);
  }
}

__global__ void CUDAInitValuesKernel4(
  const double lambda_l1,
  const double lambda_l2,
  const int num_blocks_to_reduce,
  double* cuda_sum_of_gradients,
  double* cuda_sum_of_hessians,
  int64_t* cuda_sum_of_gradients_hessians,
  const data_size_t num_data,
  const data_size_t* cuda_data_indices_in_leaf,
  hist_t* cuda_hist_in_leaf,
  CUDALeafSplitsStruct* cuda_struct) {
  __shared__ double shared_mem_buffer[32];
  double thread_sum_of_gradients = 0.0f;
  double thread_sum_of_hessians = 0.0f;
  int64_t thread_sum_of_gradients_hessians = 0;
  for (int block_index = static_cast<int>(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast<int>(blockDim.x)) {
    thread_sum_of_gradients += cuda_sum_of_gradients[block_index];
    thread_sum_of_hessians += cuda_sum_of_hessians[block_index];
    thread_sum_of_gradients_hessians += cuda_sum_of_gradients_hessians[block_index];
  }
  const double sum_of_gradients = ShuffleReduceSum<double>(thread_sum_of_gradients, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const double sum_of_hessians = ShuffleReduceSum<double>(thread_sum_of_hessians, shared_mem_buffer, blockDim.x);
  __syncthreads();
  const double sum_of_gradients_hessians = ShuffleReduceSum<int64_t>(
    thread_sum_of_gradients_hessians,
    reinterpret_cast<int64_t*>(shared_mem_buffer),
    blockDim.x);
  if (threadIdx.x == 0) {
    cuda_sum_of_hessians[0] = sum_of_hessians;
    cuda_struct->leaf_index = 0;
    cuda_struct->sum_of_gradients = sum_of_gradients;
    cuda_struct->sum_of_hessians = sum_of_hessians;
    cuda_struct->sum_of_gradients_hessians = sum_of_gradients_hessians;
    cuda_struct->num_data_in_leaf = num_data;
    const bool use_l1 = lambda_l1 > 0.0f;
    if (!use_l1) {
      // no smoothing on root node
      cuda_struct->gain = CUDALeafSplits::GetLeafGain<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    } else {
      // no smoothing on root node
      cuda_struct->gain = CUDALeafSplits::GetLeafGain<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    }
    if (!use_l1) {
      // no smoothing on root node
      cuda_struct->leaf_value =
        CUDALeafSplits::CalculateSplittedLeafOutput<false, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    } else {
      // no smoothing on root node
      cuda_struct->leaf_value =
        CUDALeafSplits::CalculateSplittedLeafOutput<true, false>(sum_of_gradients, sum_of_hessians, lambda_l1, lambda_l2, 0.0f, 0, 0.0f);
    }
    cuda_struct->data_indices_in_leaf = cuda_data_indices_in_leaf;
    cuda_struct->hist_in_leaf = cuda_hist_in_leaf;
  }
}

168
169
170
171
172
173
174
175
176
177
178
179
__global__ void InitValuesEmptyKernel(CUDALeafSplitsStruct* cuda_struct) {
  cuda_struct->leaf_index = -1;
  cuda_struct->sum_of_gradients = 0.0f;
  cuda_struct->sum_of_hessians = 0.0f;
  cuda_struct->num_data_in_leaf = 0;
  cuda_struct->gain = 0.0f;
  cuda_struct->leaf_value = 0.0f;
  cuda_struct->data_indices_in_leaf = nullptr;
  cuda_struct->hist_in_leaf = nullptr;
}

void CUDALeafSplits::LaunchInitValuesEmptyKernel() {
180
  InitValuesEmptyKernel<<<1, 1>>>(cuda_struct_.RawData());
181
182
}

183
void CUDALeafSplits::LaunchInitValuesKernel(
184
185
186
187
188
189
  const double lambda_l1, const double lambda_l2,
  const data_size_t* cuda_bagging_data_indices,
  const data_size_t* cuda_data_indices_in_leaf,
  const data_size_t num_used_indices,
  hist_t* cuda_hist_in_leaf) {
  if (cuda_bagging_data_indices == nullptr) {
190
    CUDAInitValuesKernel1<false><<<num_blocks_init_from_gradients_, NUM_THREADS_PER_BLOCK_LEAF_SPLITS>>>(
191
192
      cuda_gradients_, cuda_hessians_, num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData());
193
  } else {
194
    CUDAInitValuesKernel1<true><<<num_blocks_init_from_gradients_, NUM_THREADS_PER_BLOCK_LEAF_SPLITS>>>(
195
196
      cuda_gradients_, cuda_hessians_, num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData());
197
198
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
199
  CUDAInitValuesKernel2<<<1, NUM_THREADS_PER_BLOCK_LEAF_SPLITS>>>(
200
201
    lambda_l1, lambda_l2,
    num_blocks_init_from_gradients_,
202
203
204
205
206
207
208
209
210
    cuda_sum_of_gradients_buffer_.RawData(),
    cuda_sum_of_hessians_buffer_.RawData(),
    num_used_indices,
    cuda_data_indices_in_leaf,
    cuda_hist_in_leaf,
    cuda_struct_.RawData());
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

211
void CUDALeafSplits::LaunchInitValuesKernel(
212
213
214
215
216
217
218
219
  const double lambda_l1, const double lambda_l2,
  const data_size_t* cuda_bagging_data_indices,
  const data_size_t* cuda_data_indices_in_leaf,
  const data_size_t num_used_indices,
  hist_t* cuda_hist_in_leaf,
  const score_t* grad_scale,
  const score_t* hess_scale) {
  if (cuda_bagging_data_indices == nullptr) {
220
    CUDAInitValuesKernel3<false><<<num_blocks_init_from_gradients_, NUM_THREADS_PER_BLOCK_LEAF_SPLITS>>>(
221
222
223
      reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, nullptr, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
  } else {
224
    CUDAInitValuesKernel3<true><<<num_blocks_init_from_gradients_, NUM_THREADS_PER_BLOCK_LEAF_SPLITS>>>(
225
226
227
228
229
      reinterpret_cast<const int16_t*>(cuda_gradients_), num_used_indices, cuda_bagging_data_indices, cuda_sum_of_gradients_buffer_.RawData(),
      cuda_sum_of_hessians_buffer_.RawData(), cuda_sum_of_gradients_hessians_buffer_.RawData(), grad_scale, hess_scale);
  }

  SynchronizeCUDADevice(__FILE__, __LINE__);
230
  CUDAInitValuesKernel4<<<1, NUM_THREADS_PER_BLOCK_LEAF_SPLITS>>>(
231
232
233
234
235
    lambda_l1, lambda_l2,
    num_blocks_init_from_gradients_,
    cuda_sum_of_gradients_buffer_.RawData(),
    cuda_sum_of_hessians_buffer_.RawData(),
    cuda_sum_of_gradients_hessians_buffer_.RawData(),
236
237
238
    num_used_indices,
    cuda_data_indices_in_leaf,
    cuda_hist_in_leaf,
239
    cuda_struct_.RawData());
240
241
242
243
244
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

}  // namespace LightGBM

245
#endif  // USE_CUDA