cuda_tree.cpp 16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifdef USE_CUDA_EXP

#include <LightGBM/cuda/cuda_tree.hpp>

namespace LightGBM {

/*!
 * \brief Create an empty CUDA tree whose device buffers are sized for max_leaves leaves.
 * \param max_leaves Maximum number of leaves the tree may grow to
 * \param track_branch_features Whether RecordBranchFeatures should record the features used along each leaf's path
 * \param is_linear Forwarded unchanged to the host-side Tree
 * \param gpu_device_id CUDA device to bind; any negative value selects device 0
 * \param has_categorical_feature When true, reserves max_leaves entries in the categorical boundary vectors
 */
CUDATree::CUDATree(int max_leaves, bool track_branch_features, bool is_linear,
  const int gpu_device_id, const bool has_categorical_feature):
Tree(max_leaves, track_branch_features, is_linear),
num_threads_per_block_add_prediction_to_score_(1024) {
  is_cuda_tree_ = true;
  // A negative id means "no explicit device"; fall back to the first device.
  const int device_id = (gpu_device_id < 0) ? 0 : gpu_device_id;
  SetCUDADevice(device_id, __FILE__, __LINE__);
  if (has_categorical_feature) {
    cuda_cat_boundaries_.Resize(max_leaves);
    cuda_cat_boundaries_inner_.Resize(max_leaves);
  }
  InitCUDAMemory();
}

/*!
 * \brief Build a CUDA tree from an existing host tree.
 *
 * All host-side state is taken over through Tree's copy constructor; InitCUDA()
 * then uploads the host arrays into freshly allocated device buffers.
 * NOTE(review): unlike the other constructor, no CUDA device is selected here, so
 * allocations land on whichever device is current — confirm this is intended.
 * \param host_tree Tree to copy; must be non-null (it is dereferenced unconditionally)
 */
CUDATree::CUDATree(const Tree* host_tree)
  : Tree(*host_tree),
    num_threads_per_block_add_prediction_to_score_(1024) {
  is_cuda_tree_ = true;
  InitCUDA();
}

/*!
 * \brief Release every device buffer owned by this tree and destroy its CUDA stream.
 */
CUDATree::~CUDATree() {
  // Per-internal-node buffers.
  DeallocateCUDAMemory<int>(&cuda_left_child_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_right_child_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_split_feature_inner_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_split_feature_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_threshold_in_bin_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_threshold_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int8_t>(&cuda_decision_type_, __FILE__, __LINE__);
  DeallocateCUDAMemory<float>(&cuda_split_gain_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_internal_value_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_internal_weight_, __FILE__, __LINE__);
  DeallocateCUDAMemory<data_size_t>(&cuda_internal_count_, __FILE__, __LINE__);

  // Per-leaf buffers.
  DeallocateCUDAMemory<double>(&cuda_leaf_value_, __FILE__, __LINE__);
  DeallocateCUDAMemory<double>(&cuda_leaf_weight_, __FILE__, __LINE__);
  DeallocateCUDAMemory<data_size_t>(&cuda_leaf_count_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_leaf_depth_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_leaf_parent_, __FILE__, __LINE__);

  // The stream is created by InitCUDAMemory() / InitCUDA().
  gpuAssert(cudaStreamDestroy(cuda_stream_), __FILE__, __LINE__);
}

/*!
 * \brief Allocate all device buffers for an empty tree and initialize its root leaf.
 *
 * Every buffer receives max_leaves_ entries. That covers both the per-leaf
 * arrays (max_leaves_ entries) and the per-internal-node arrays (at most
 * max_leaves_ - 1 entries). Only leaf 0 (the root, the single leaf of a fresh
 * tree) is initialized. Its depth and count are set explicitly because ToHost()
 * copies num_leaves_ entries of those arrays back to the host; without this,
 * an unsplit tree would return uninitialized device memory.
 */
void CUDATree::InitCUDAMemory() {
  const size_t capacity = static_cast<size_t>(max_leaves_);

  // Per-internal-node arrays.
  AllocateCUDAMemory<int>(&cuda_left_child_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_right_child_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_split_feature_inner_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_split_feature_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<uint32_t>(&cuda_threshold_in_bin_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_threshold_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<int8_t>(&cuda_decision_type_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_internal_weight_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_internal_value_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<data_size_t>(&cuda_internal_count_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<float>(&cuda_split_gain_, capacity, __FILE__, __LINE__);

  // Per-leaf arrays.
  AllocateCUDAMemory<int>(&cuda_leaf_depth_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_leaf_parent_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_leaf_value_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<double>(&cuda_leaf_weight_, capacity, __FILE__, __LINE__);
  AllocateCUDAMemory<data_size_t>(&cuda_leaf_count_, capacity, __FILE__, __LINE__);

  // Root leaf (index 0): zero output, weight and count; depth 0; no parent (-1).
  // SetCUDAMemory fills bytes, so -1 yields all-ones bits, i.e. int -1.
  SetCUDAMemory<double>(cuda_leaf_value_, 0.0f, 1, __FILE__, __LINE__);
  SetCUDAMemory<double>(cuda_leaf_weight_, 0.0f, 1, __FILE__, __LINE__);
  SetCUDAMemory<data_size_t>(cuda_leaf_count_, 0, 1, __FILE__, __LINE__);
  SetCUDAMemory<int>(cuda_leaf_depth_, 0, 1, __FILE__, __LINE__);
  SetCUDAMemory<int>(cuda_leaf_parent_, -1, 1, __FILE__, __LINE__);

  CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_));
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

/*!
 * \brief Upload every host-side tree array into a newly allocated device buffer of the same length.
 *
 * Used when converting an existing host Tree into a CUDATree. Device buffers
 * are sized by the host vectors' sizes, not by max_leaves_.
 * NOTE(review): categorical split data (cat_boundaries_, cat_boundaries_inner_,
 * cat_threshold_, cat_threshold_inner_) is not uploaded here, yet ToHost() reads
 * it back from the device whenever num_cat_ > 0. Confirm whether host trees with
 * categorical splits can reach this path.
 */
void CUDATree::InitCUDA() {
  // Per-internal-node arrays.
  InitCUDAMemoryFromHostMemory<int>(
    &cuda_left_child_, left_child_.data(), left_child_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<int>(
    &cuda_right_child_, right_child_.data(), right_child_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<int>(
    &cuda_split_feature_inner_, split_feature_inner_.data(), split_feature_inner_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<int>(
    &cuda_split_feature_, split_feature_.data(), split_feature_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<uint32_t>(
    &cuda_threshold_in_bin_, threshold_in_bin_.data(), threshold_in_bin_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<double>(
    &cuda_threshold_, threshold_.data(), threshold_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<int8_t>(
    &cuda_decision_type_, decision_type_.data(), decision_type_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<float>(
    &cuda_split_gain_, split_gain_.data(), split_gain_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<double>(
    &cuda_internal_value_, internal_value_.data(), internal_value_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<double>(
    &cuda_internal_weight_, internal_weight_.data(), internal_weight_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<data_size_t>(
    &cuda_internal_count_, internal_count_.data(), internal_count_.size(), __FILE__, __LINE__);

  // Per-leaf arrays.
  InitCUDAMemoryFromHostMemory<double>(
    &cuda_leaf_value_, leaf_value_.data(), leaf_value_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<double>(
    &cuda_leaf_weight_, leaf_weight_.data(), leaf_weight_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<data_size_t>(
    &cuda_leaf_count_, leaf_count_.data(), leaf_count_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<int>(
    &cuda_leaf_depth_, leaf_depth_.data(), leaf_depth_.size(), __FILE__, __LINE__);
  InitCUDAMemoryFromHostMemory<int>(
    &cuda_leaf_parent_, leaf_parent_.data(), leaf_parent_.size(), __FILE__, __LINE__);

  CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_stream_));
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

int CUDATree::Split(const int leaf_index,
           const int real_feature_index,
           const double real_threshold,
           const MissingType missing_type,
           const CUDASplitInfo* cuda_split_info) {
  LaunchSplitKernel(leaf_index, real_feature_index, real_threshold, missing_type, cuda_split_info);
219
  RecordBranchFeatures(leaf_index, num_leaves_, real_feature_index);
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
  ++num_leaves_;
  return num_leaves_ - 1;
}

int CUDATree::SplitCategorical(const int leaf_index,
           const int real_feature_index,
           const MissingType missing_type,
           const CUDASplitInfo* cuda_split_info,
           uint32_t* cuda_bitset,
           size_t cuda_bitset_len,
           uint32_t* cuda_bitset_inner,
           size_t cuda_bitset_inner_len) {
  LaunchSplitCategoricalKernel(leaf_index, real_feature_index,
    missing_type, cuda_split_info,
    cuda_bitset_len, cuda_bitset_inner_len);
  cuda_bitset_.PushBack(cuda_bitset, cuda_bitset_len);
  cuda_bitset_inner_.PushBack(cuda_bitset_inner, cuda_bitset_inner_len);
  ++num_leaves_;
  ++num_cat_;
239
  RecordBranchFeatures(leaf_index, num_leaves_, real_feature_index);
240
241
242
  return num_leaves_ - 1;
}

243
244
245
246
247
248
249
250
251
252
/*!
 * \brief Record real_feature_index on the feature paths of both children of a split.
 *
 * Does nothing unless track_branch_features_ is set. Otherwise the right leaf
 * first inherits the left leaf's existing path, and then both paths get
 * real_feature_index appended.
 * \param left_leaf_index Leaf that was split (keeps its existing path)
 * \param right_leaf_index Newly created leaf (receives a copy of the left path)
 * \param real_feature_index Feature used by the split
 */
void CUDATree::RecordBranchFeatures(const int left_leaf_index,
                                    const int right_leaf_index,
                                    const int real_feature_index) {
  if (!track_branch_features_) {
    return;
  }
  auto& left_path = branch_features_[left_leaf_index];
  auto& right_path = branch_features_[right_leaf_index];
  right_path = left_path;
  right_path.push_back(real_feature_index);
  left_path.push_back(real_feature_index);
}

253
254
255
256
257
258
259
260
261
262
263
264
265
266
/*!
 * \brief Add this tree's predictions into score, without a row-index subset.
 *
 * Forwards to the indexed overload with used_data_indices == nullptr. The call
 * is qualified so it stays non-virtual, as the original direct launch was.
 */
void CUDATree::AddPredictionToScore(const Dataset* data,
                                    data_size_t num_data,
                                    double* score) const {
  CUDATree::AddPredictionToScore(data, nullptr, num_data, score);
}

/*!
 * \brief Add this tree's predictions for num_data rows of data into score.
 *
 * Launches the prediction kernel and blocks until the device has finished.
 * \param data Dataset whose rows are scored
 * \param used_data_indices Optional row-index subset (nullptr = none); presumably num_data entries
 * \param num_data Number of rows to score
 * \param score Output buffer the predictions are added to (presumably device memory)
 */
void CUDATree::AddPredictionToScore(const Dataset* data,
                                    const data_size_t* used_data_indices,
                                    data_size_t num_data,
                                    double* score) const {
  LaunchAddPredictionToScoreKernel(data, used_data_indices, num_data, score);
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
/*!
 * \brief Apply shrinkage rate to both the host-side Tree and the device-side copy.
 *
 * The host state is updated through Tree::Shrinkage and the device buffers
 * through LaunchShrinkageKernel. The definition is deliberately not `inline`:
 * an inline function defined only in this .cpp file cannot be used by other
 * translation units that odr-use it, which makes it a latent link/ODR hazard.
 * \param rate Shrinkage factor, forwarded to both updates
 */
void CUDATree::Shrinkage(double rate) {
  Tree::Shrinkage(rate);
  LaunchShrinkageKernel(rate);
}

/*!
 * \brief Add bias val to both the host-side Tree and the device-side copy.
 *
 * The host state is updated through Tree::AddBias and the device buffers
 * through LaunchAddBiasKernel. The definition is deliberately not `inline`:
 * an inline function defined only in this .cpp file cannot be used by other
 * translation units that odr-use it, which makes it a latent link/ODR hazard.
 * \param val Bias value, forwarded to both updates
 */
void CUDATree::AddBias(double val) {
  Tree::AddBias(val);
  LaunchAddBiasKernel(val);
}

/*!
 * \brief Copy the tree from device memory back into the host-side Tree arrays.
 *
 * Host arrays are first resized to full capacity: max_leaves_ - 1 entries for
 * per-internal-node arrays and max_leaves_ for per-leaf arrays. Then only the
 * entries in use are copied: num_leaves_ - 1 internal nodes and num_leaves_
 * leaves. Categorical split data is copied only when num_cat_ > 0. The call
 * ends with a device synchronization.
 */
void CUDATree::ToHost() {
  left_child_.resize(max_leaves_ - 1);
  right_child_.resize(max_leaves_ - 1);
  split_feature_inner_.resize(max_leaves_ - 1);
  split_feature_.resize(max_leaves_ - 1);
  threshold_in_bin_.resize(max_leaves_ - 1);
  threshold_.resize(max_leaves_ - 1);
  decision_type_.resize(max_leaves_ - 1, 0);
  split_gain_.resize(max_leaves_ - 1);
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_weight_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  internal_value_.resize(max_leaves_ - 1);
  internal_weight_.resize(max_leaves_ - 1);
  internal_count_.resize(max_leaves_ - 1);
  leaf_depth_.resize(max_leaves_);

  const size_t num_leaves_size = static_cast<size_t>(num_leaves_);
  // A binary tree with n leaves has n - 1 internal nodes.
  const size_t num_internal_nodes = num_leaves_size - 1;

  // Per-internal-node arrays.
  CopyFromCUDADeviceToHost<int>(left_child_.data(), cuda_left_child_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<int>(right_child_.data(), cuda_right_child_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<int>(split_feature_inner_.data(), cuda_split_feature_inner_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<int>(split_feature_.data(), cuda_split_feature_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<uint32_t>(threshold_in_bin_.data(), cuda_threshold_in_bin_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<double>(threshold_.data(), cuda_threshold_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<int8_t>(decision_type_.data(), cuda_decision_type_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<float>(split_gain_.data(), cuda_split_gain_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<double>(internal_value_.data(), cuda_internal_value_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<double>(internal_weight_.data(), cuda_internal_weight_, num_internal_nodes, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<data_size_t>(internal_count_.data(), cuda_internal_count_, num_internal_nodes, __FILE__, __LINE__);

  // Per-leaf arrays: one entry per existing leaf. leaf_parent_ is indexed by
  // leaf (it is resized to max_leaves_ above), so all num_leaves_ entries are
  // needed. Copying only num_leaves_ - 1 left the newest leaf's parent stale.
  CopyFromCUDADeviceToHost<int>(leaf_parent_.data(), cuda_leaf_parent_, num_leaves_size, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<double>(leaf_value_.data(), cuda_leaf_value_, num_leaves_size, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<double>(leaf_weight_.data(), cuda_leaf_weight_, num_leaves_size, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<data_size_t>(leaf_count_.data(), cuda_leaf_count_, num_leaves_size, __FILE__, __LINE__);
  CopyFromCUDADeviceToHost<int>(leaf_depth_.data(), cuda_leaf_depth_, num_leaves_size, __FILE__, __LINE__);

  if (num_cat_ > 0) {
    // num_cat_ categorical splits need num_cat_ + 1 boundary entries; the
    // device vectors are presumably trimmed to that length before the copy.
    cuda_cat_boundaries_inner_.Resize(num_cat_ + 1);
    cuda_cat_boundaries_.Resize(num_cat_ + 1);
    cat_boundaries_ = cuda_cat_boundaries_.ToHost();
    cat_boundaries_inner_ = cuda_cat_boundaries_inner_.ToHost();
    cat_threshold_ = cuda_bitset_.ToHost();
    cat_threshold_inner_ = cuda_bitset_inner_.ToHost();
  }

  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDATree::SyncLeafOutputFromHostToCUDA() {
  CopyFromHostToCUDADevice<double>(cuda_leaf_value_, leaf_value_.data(), leaf_value_.size(), __FILE__, __LINE__);
}

void CUDATree::SyncLeafOutputFromCUDAToHost() {
  CopyFromCUDADeviceToHost<double>(leaf_value_.data(), cuda_leaf_value_, leaf_value_.size(), __FILE__, __LINE__);
}

333
334
335
336
void CUDATree::AsConstantTree(double val) {
  Tree::AsConstantTree(val);
  CopyFromHostToCUDADevice<double>(cuda_leaf_value_, &val, 1, __FILE__, __LINE__);
}
337
338
339
340

}  // namespace LightGBM

#endif  // USE_CUDA_EXP