"vscode:/vscode.git/clone" did not exist on "edfc2035abde13b53d17d02201604eed7540f90e"
cuda_best_split_finder.cpp 18.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifdef USE_CUDA_EXP

#include <algorithm>

#include "cuda_best_split_finder.hpp"
#include "cuda_leaf_splits.hpp"

namespace LightGBM {

// Builds the split finder for `train_data`: copies the split-related hyper-parameters out of `config`
// and computes the host-side per-feature metadata. No device memory is allocated here; Init() does that.
CUDABestSplitFinder::CUDABestSplitFinder(
  const hist_t* cuda_hist,
  const Dataset* train_data,
  const std::vector<uint32_t>& feature_hist_offsets,
  const bool select_features_by_node,
  const Config* config):
  num_features_(train_data->num_features()),
  num_leaves_(config->num_leaves),
  feature_hist_offsets_(feature_hist_offsets),
  lambda_l1_(config->lambda_l1),
  lambda_l2_(config->lambda_l2),
  min_data_in_leaf_(config->min_data_in_leaf),
  min_sum_hessian_in_leaf_(config->min_sum_hessian_in_leaf),
  min_gain_to_split_(config->min_gain_to_split),
  cat_smooth_(config->cat_smooth),
  cat_l2_(config->cat_l2),
  max_cat_threshold_(config->max_cat_threshold),
  min_data_per_group_(config->min_data_per_group),
  max_cat_to_onehot_(config->max_cat_to_onehot),
  extra_trees_(config->extra_trees),
  extra_seed_(config->extra_seed),
  use_smoothing_(config->path_smooth > 0),
  path_smooth_(config->path_smooth),
  num_total_bin_(feature_hist_offsets.empty() ? 0 : static_cast<int>(feature_hist_offsets.back())),
  select_features_by_node_(select_features_by_node),
  cuda_hist_(cuda_hist) {
  InitFeatureMetaInfo(train_data);
  // Every raw device pointer owned by this object starts out null so the destructor can release
  // exactly what Init()/InitCUDAFeatureMetaInfo() allocated.
  // NOTE(review): relies on DeallocateCUDAMemory treating nullptr as a no-op, which the original
  // null-then-free pattern for the split-info pointers already assumes.
  cuda_leaf_best_split_info_ = nullptr;
  cuda_best_split_info_ = nullptr;
  cuda_best_split_info_buffer_ = nullptr;
  cuda_is_feature_used_bytree_ = nullptr;
  cuda_feature_hist_grad_buffer_ = nullptr;
  cuda_feature_hist_hess_buffer_ = nullptr;
  cuda_feature_hist_stat_buffer_ = nullptr;
  cuda_feature_hist_index_buffer_ = nullptr;
  cuda_cat_threshold_feature_ = nullptr;
  cuda_cat_threshold_real_feature_ = nullptr;
  cuda_cat_threshold_leaf_ = nullptr;
  cuda_cat_threshold_real_leaf_ = nullptr;
}

// Releases every device buffer and CUDA stream owned by this object.
CUDABestSplitFinder::~CUDABestSplitFinder() {
  DeallocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_, __FILE__, __LINE__);
  DeallocateCUDAMemory<CUDASplitInfo>(&cuda_best_split_info_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_best_split_info_buffer_, __FILE__, __LINE__);
  cuda_split_find_tasks_.Clear();
  DeallocateCUDAMemory<int8_t>(&cuda_is_feature_used_bytree_, __FILE__, __LINE__);
  // Global-memory histogram scratch buffers (allocated in Init() only when use_global_memory_ is set).
  DeallocateCUDAMemory<hist_t>(&cuda_feature_hist_grad_buffer_, __FILE__, __LINE__);
  DeallocateCUDAMemory<hist_t>(&cuda_feature_hist_hess_buffer_, __FILE__, __LINE__);
  DeallocateCUDAMemory<hist_t>(&cuda_feature_hist_stat_buffer_, __FILE__, __LINE__);
  DeallocateCUDAMemory<data_size_t>(&cuda_feature_hist_index_buffer_, __FILE__, __LINE__);
  // Categorical threshold storage referenced by the split-info arrays.
  DeallocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, __FILE__, __LINE__);
  DeallocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, __FILE__, __LINE__);
  // Streams exist only after Init(). Iterate rather than index so an object that was never
  // initialized can be destroyed without reading past the end of an empty vector.
  for (const cudaStream_t& stream : cuda_streams_) {
    gpuAssert(cudaStreamDestroy(stream), __FILE__, __LINE__);
  }
  cuda_streams_.clear();
  cuda_streams_.shrink_to_fit();
}

// Collects host-side per-feature metadata from the dataset's bin mappers: missing type, whether the
// most frequent bin is bin 0 (mfb offset), default bin, bin count and categorical flag. Also tracks
// the widest histogram and the largest categorical bin count, and decides whether split finding
// must fall back to global-memory buffers (widest histogram exceeds one block of threads).
// Called from the constructor and again from ResetTrainingData(), so all outputs are fully rewritten.
void CUDABestSplitFinder::InitFeatureMetaInfo(const Dataset* train_data) {
  feature_missing_type_.resize(num_features_);
  feature_mfb_offsets_.resize(num_features_);
  feature_default_bins_.resize(num_features_);
  feature_num_bins_.resize(num_features_);
  // assign (not resize): resize would keep categorical flags left over from a previous dataset,
  // because the loop below only ever sets flags to 1 and never clears them.
  is_categorical_.assign(static_cast<size_t>(num_features_), 0);
  max_num_bin_in_feature_ = 0;
  has_categorical_feature_ = false;
  max_num_categorical_bin_ = 0;
  for (int inner_feature_index = 0; inner_feature_index < num_features_; ++inner_feature_index) {
    const BinMapper* bin_mapper = train_data->FeatureBinMapper(inner_feature_index);
    if (bin_mapper->bin_type() == BinType::CategoricalBin) {
      has_categorical_feature_ = true;
      is_categorical_[inner_feature_index] = 1;
      if (bin_mapper->num_bin() > max_num_categorical_bin_) {
        max_num_categorical_bin_ = bin_mapper->num_bin();
      }
    }
    const MissingType missing_type = bin_mapper->missing_type();
    feature_missing_type_[inner_feature_index] = missing_type;
    feature_mfb_offsets_[inner_feature_index] = static_cast<int8_t>(bin_mapper->GetMostFreqBin() == 0);
    feature_default_bins_[inner_feature_index] = bin_mapper->GetDefaultBin();
    feature_num_bins_[inner_feature_index] = static_cast<uint32_t>(bin_mapper->num_bin());
    // Number of histogram entries actually stored for this feature (most-frequent bin 0 is omitted).
    const int num_bin_hist = bin_mapper->num_bin() - feature_mfb_offsets_[inner_feature_index];
    if (num_bin_hist > max_num_bin_in_feature_) {
      max_num_bin_in_feature_ = num_bin_hist;
    }
  }
  // A histogram wider than one thread block cannot be handled by the per-block path.
  use_global_memory_ = (max_num_bin_in_feature_ > NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER);
}

void CUDABestSplitFinder::Init() {
  InitCUDAFeatureMetaInfo();
  cuda_streams_.resize(2);
  CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_streams_[0]));
  CUDASUCCESS_OR_FATAL(cudaStreamCreate(&cuda_streams_[1]));
  AllocateCUDAMemory<int>(&cuda_best_split_info_buffer_, 8, __FILE__, __LINE__);
  if (use_global_memory_) {
    AllocateCUDAMemory<hist_t>(&cuda_feature_hist_grad_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
    AllocateCUDAMemory<hist_t>(&cuda_feature_hist_hess_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
    if (has_categorical_feature_) {
      AllocateCUDAMemory<hist_t>(&cuda_feature_hist_stat_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
      AllocateCUDAMemory<data_size_t>(&cuda_feature_hist_index_buffer_, static_cast<size_t>(num_total_bin_), __FILE__, __LINE__);
    }
  }
110
111
112
113
114

  if (select_features_by_node_) {
    is_feature_used_by_smaller_node_.Resize(num_features_);
    is_feature_used_by_larger_node_.Resize(num_features_);
  }
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
}

void CUDABestSplitFinder::InitCUDAFeatureMetaInfo() {
  AllocateCUDAMemory<int8_t>(&cuda_is_feature_used_bytree_, static_cast<size_t>(num_features_), __FILE__, __LINE__);

  // intialize split find task information (a split find task is one pass through the histogram of a feature)
  num_tasks_ = 0;
  for (int inner_feature_index = 0; inner_feature_index < num_features_; ++inner_feature_index) {
    const uint32_t num_bin = feature_num_bins_[inner_feature_index];
    const MissingType missing_type = feature_missing_type_[inner_feature_index];
    if (num_bin > 2 && missing_type != MissingType::None && !is_categorical_[inner_feature_index]) {
      num_tasks_ += 2;
    } else {
      ++num_tasks_;
    }
  }
  split_find_tasks_.resize(num_tasks_);
  split_find_tasks_.shrink_to_fit();
  int cur_task_index = 0;
  for (int inner_feature_index = 0; inner_feature_index < num_features_; ++inner_feature_index) {
    const uint32_t num_bin = feature_num_bins_[inner_feature_index];
    const MissingType missing_type = feature_missing_type_[inner_feature_index];
    if (num_bin > 2 && missing_type != MissingType::None && !is_categorical_[inner_feature_index]) {
      if (missing_type == MissingType::Zero) {
        SplitFindTask* new_task = &split_find_tasks_[cur_task_index];
        new_task->reverse = false;
        new_task->skip_default_bin = true;
        new_task->na_as_missing = false;
        new_task->inner_feature_index = inner_feature_index;
        new_task->assume_out_default_left = false;
        new_task->is_categorical = false;
        uint32_t num_bin = feature_num_bins_[inner_feature_index];
        new_task->is_one_hot = false;
        new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
        new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
        new_task->default_bin = feature_default_bins_[inner_feature_index];
        new_task->num_bin = num_bin;
        ++cur_task_index;

        new_task = &split_find_tasks_[cur_task_index];
        new_task->reverse = true;
        new_task->skip_default_bin = true;
        new_task->na_as_missing = false;
        new_task->inner_feature_index = inner_feature_index;
        new_task->assume_out_default_left = true;
        new_task->is_categorical = false;
        num_bin = feature_num_bins_[inner_feature_index];
        new_task->is_one_hot = false;
        new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
        new_task->default_bin = feature_default_bins_[inner_feature_index];
        new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
        new_task->num_bin = num_bin;
        ++cur_task_index;
      } else {
        SplitFindTask* new_task = &split_find_tasks_[cur_task_index];
        new_task->reverse = false;
        new_task->skip_default_bin = false;
        new_task->na_as_missing = true;
        new_task->inner_feature_index = inner_feature_index;
        new_task->assume_out_default_left = false;
        new_task->is_categorical = false;
        uint32_t num_bin = feature_num_bins_[inner_feature_index];
        new_task->is_one_hot = false;
        new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
        new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
        new_task->default_bin = feature_default_bins_[inner_feature_index];
        new_task->num_bin = num_bin;
        ++cur_task_index;

        new_task = &split_find_tasks_[cur_task_index];
        new_task->reverse = true;
        new_task->skip_default_bin = false;
        new_task->na_as_missing = true;
        new_task->inner_feature_index = inner_feature_index;
        new_task->assume_out_default_left = true;
        new_task->is_categorical = false;
        num_bin = feature_num_bins_[inner_feature_index];
        new_task->is_one_hot = false;
        new_task->hist_offset = feature_hist_offsets_[inner_feature_index];
        new_task->mfb_offset = feature_mfb_offsets_[inner_feature_index];
        new_task->default_bin = feature_default_bins_[inner_feature_index];
        new_task->num_bin = num_bin;
        ++cur_task_index;
      }
    } else {
      SplitFindTask& new_task = split_find_tasks_[cur_task_index];
      const uint32_t num_bin = feature_num_bins_[inner_feature_index];
      if (is_categorical_[inner_feature_index]) {
        new_task.reverse = false;
        new_task.is_categorical = true;
        new_task.is_one_hot = (static_cast<int>(num_bin) <= max_cat_to_onehot_);
      } else {
        new_task.reverse = true;
        new_task.is_categorical = false;
        new_task.is_one_hot = false;
      }
      new_task.skip_default_bin = false;
      new_task.na_as_missing = false;
      new_task.inner_feature_index = inner_feature_index;
      if (missing_type != MissingType::NaN && !is_categorical_[inner_feature_index]) {
        new_task.assume_out_default_left = true;
      } else {
        new_task.assume_out_default_left = false;
      }
      new_task.hist_offset = feature_hist_offsets_[inner_feature_index];
      new_task.mfb_offset = feature_mfb_offsets_[inner_feature_index];
      new_task.default_bin = feature_default_bins_[inner_feature_index];
      new_task.num_bin = num_bin;
      ++cur_task_index;
    }
  }
  CHECK_EQ(cur_task_index, static_cast<int>(split_find_tasks_.size()));

  if (extra_trees_) {
    cuda_randoms_.Resize(num_tasks_ * 2);
    LaunchInitCUDARandomKernel();
  }

  const int num_task_blocks = (num_tasks_ + NUM_TASKS_PER_SYNC_BLOCK - 1) / NUM_TASKS_PER_SYNC_BLOCK;
  const size_t cuda_best_leaf_split_info_buffer_size = static_cast<size_t>(num_task_blocks) * static_cast<size_t>(num_leaves_);

  AllocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_,
                                    cuda_best_leaf_split_info_buffer_size,
                                    __FILE__,
                                    __LINE__);

  cuda_split_find_tasks_.Resize(num_tasks_);
  CopyFromHostToCUDADevice<SplitFindTask>(cuda_split_find_tasks_.RawData(),
                                          split_find_tasks_.data(),
                                          split_find_tasks_.size(),
                                          __FILE__,
                                          __LINE__);

  const size_t output_buffer_size = 2 * static_cast<size_t>(num_tasks_);
  AllocateCUDAMemory<CUDASplitInfo>(&cuda_best_split_info_, output_buffer_size, __FILE__, __LINE__);

  max_num_categories_in_split_ = std::min(max_cat_threshold_, max_num_categorical_bin_ / 2);
  AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, max_num_categories_in_split_ * output_buffer_size, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, max_num_categories_in_split_ * output_buffer_size, __FILE__, __LINE__);
  AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, max_num_categories_in_split_ * cuda_best_leaf_split_info_buffer_size, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, max_num_categories_in_split_ * cuda_best_leaf_split_info_buffer_size, __FILE__, __LINE__);
  AllocateCatVectors(cuda_leaf_best_split_info_, cuda_cat_threshold_leaf_, cuda_cat_threshold_real_leaf_, cuda_best_leaf_split_info_buffer_size);
  AllocateCatVectors(cuda_best_split_info_, cuda_cat_threshold_feature_, cuda_cat_threshold_real_feature_, output_buffer_size);
}

void CUDABestSplitFinder::ResetTrainingData(
  const hist_t* cuda_hist,
  const Dataset* train_data,
  const std::vector<uint32_t>& feature_hist_offsets) {
  cuda_hist_ = cuda_hist;
  num_features_ = train_data->num_features();
  feature_hist_offsets_ = feature_hist_offsets;
  InitFeatureMetaInfo(train_data);
  DeallocateCUDAMemory<int8_t>(&cuda_is_feature_used_bytree_, __FILE__, __LINE__);
  DeallocateCUDAMemory<CUDASplitInfo>(&cuda_best_split_info_, __FILE__, __LINE__);
  InitCUDAFeatureMetaInfo();
}

// Applies a new configuration. Reloads the split hyper-parameters and re-sizes the buffers that
// depend on them: the per-leaf split infos (num_leaves), the categorical-threshold storage
// (max_cat_threshold) and the extra-trees random states.
void CUDABestSplitFinder::ResetConfig(const Config* config, const hist_t* cuda_hist) {
  num_leaves_ = config->num_leaves;
  lambda_l1_ = config->lambda_l1;
  lambda_l2_ = config->lambda_l2;
  min_data_in_leaf_ = config->min_data_in_leaf;
  min_sum_hessian_in_leaf_ = config->min_sum_hessian_in_leaf;
  min_gain_to_split_ = config->min_gain_to_split;
  cat_smooth_ = config->cat_smooth;
  cat_l2_ = config->cat_l2;
  max_cat_threshold_ = config->max_cat_threshold;
  min_data_per_group_ = config->min_data_per_group;
  max_cat_to_onehot_ = config->max_cat_to_onehot;
  extra_trees_ = config->extra_trees;
  extra_seed_ = config->extra_seed;
  use_smoothing_ = (config->path_smooth > 0.0f);
  path_smooth_ = config->path_smooth;
  cuda_hist_ = cuda_hist;

  // Per-leaf best split buffer: one entry per (task block, leaf) pair, so it depends on num_leaves_.
  const int num_task_blocks = (num_tasks_ + NUM_TASKS_PER_SYNC_BLOCK - 1) / NUM_TASKS_PER_SYNC_BLOCK;
  const size_t leaf_buffer_size = static_cast<size_t>(num_task_blocks) * static_cast<size_t>(num_leaves_);
  DeallocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_, __FILE__, __LINE__);
  AllocateCUDAMemory<CUDASplitInfo>(&cuda_leaf_best_split_info_,
                                    leaf_buffer_size,
                                    __FILE__,
                                    __LINE__);
  max_num_categories_in_split_ = std::min(max_cat_threshold_, max_num_categorical_bin_ / 2);
  const size_t leaf_cat_threshold_size = max_num_categories_in_split_ * leaf_buffer_size;
  DeallocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, __FILE__, __LINE__);
  AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_leaf_, leaf_cat_threshold_size, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_cat_threshold_real_leaf_, leaf_cat_threshold_size, __FILE__, __LINE__);
  AllocateCatVectors(cuda_leaf_best_split_info_, cuda_cat_threshold_leaf_, cuda_cat_threshold_real_leaf_, leaf_buffer_size);

  // Per-task split buffer (2 * num_tasks_ entries) does not depend on the config; only its
  // categorical-threshold storage must follow the new max_num_categories_in_split_.
  const size_t feature_buffer_size = 2 * static_cast<size_t>(num_tasks_);
  const size_t feature_cat_threshold_size = max_num_categories_in_split_ * feature_buffer_size;
  DeallocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, __FILE__, __LINE__);
  DeallocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, __FILE__, __LINE__);
  AllocateCUDAMemory<uint32_t>(&cuda_cat_threshold_feature_, feature_cat_threshold_size, __FILE__, __LINE__);
  AllocateCUDAMemory<int>(&cuda_cat_threshold_real_feature_, feature_cat_threshold_size, __FILE__, __LINE__);
  AllocateCatVectors(cuda_best_split_info_, cuda_cat_threshold_feature_, cuda_cat_threshold_real_feature_, feature_buffer_size);

  // Random states are otherwise created only once, when extra_trees is on at initialization.
  // Recreate them so enabling extra_trees through a config reset does not leave them missing.
  // NOTE(review): presumably LaunchInitCUDARandomKernel seeds from extra_seed_. If so, this also
  // makes an extra_seed change take effect, as the CPU learner does on reset. Confirm.
  if (extra_trees_) {
    cuda_randoms_.Resize(num_tasks_ * 2);
    LaunchInitCUDARandomKernel();
  }
}

// Uploads the per-tree feature usage mask into the device buffer cuda_is_feature_used_bytree_.
void CUDABestSplitFinder::BeforeTrain(const std::vector<int8_t>& is_feature_used_bytree) {
  const int8_t* host_mask = is_feature_used_bytree.data();
  const size_t mask_length = is_feature_used_bytree.size();
  CopyFromHostToCUDADevice<int8_t>(cuda_is_feature_used_bytree_, host_mask, mask_length,
                                   __FILE__, __LINE__);
}

// Searches the best split for the smaller and larger leaves of the last split, then runs the sync
// kernel and blocks until the device is idle.
// A leaf takes part only if it holds strictly more than min_data_in_leaf_ rows and strictly more
// than min_sum_hessian_in_leaf_ hessian. A negative larger_leaf_index also marks the larger leaf
// invalid (presumably meaning there is no larger leaf this round).
void CUDABestSplitFinder::FindBestSplitsForLeaf(
  const CUDALeafSplitsStruct* smaller_leaf_splits,
  const CUDALeafSplitsStruct* larger_leaf_splits,
  const int smaller_leaf_index,
  const int larger_leaf_index,
  const data_size_t num_data_in_smaller_leaf,
  const data_size_t num_data_in_larger_leaf,
  const double sum_hessians_in_smaller_leaf,
  const double sum_hessians_in_larger_leaf) {
  const bool smaller_ok = (num_data_in_smaller_leaf > min_data_in_leaf_) &&
                          (sum_hessians_in_smaller_leaf > min_sum_hessian_in_leaf_);
  const bool larger_ok = (larger_leaf_index >= 0) &&
                         (num_data_in_larger_leaf > min_data_in_leaf_) &&
                         (sum_hessians_in_larger_leaf > min_sum_hessian_in_leaf_);

  LaunchFindBestSplitsForLeafKernel(smaller_leaf_splits, larger_leaf_splits,
                                    smaller_leaf_index, larger_leaf_index,
                                    smaller_ok, larger_ok);

  const char* const sync_timer_name = "CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel";
  global_timer.Start(sync_timer_name);
  LaunchSyncBestSplitForLeafKernel(smaller_leaf_index, larger_leaf_index, smaller_ok, larger_ok);
  SynchronizeCUDADevice(__FILE__, __LINE__);
  global_timer.Stop(sync_timer_name);
}

// Runs the "best split across all leaves" kernel, waits for the device, then returns the device
// address of the selected entry in cuda_leaf_best_split_info_.
// *best_leaf_index is read on the host after synchronization. The launcher presumably fills it,
// along with the other output pointers.
const CUDASplitInfo* CUDABestSplitFinder::FindBestFromAllSplits(
    const int cur_num_leaves,
    const int smaller_leaf_index,
    const int larger_leaf_index,
    int* smaller_leaf_best_split_feature,
    uint32_t* smaller_leaf_best_split_threshold,
    uint8_t* smaller_leaf_best_split_default_left,
    int* larger_leaf_best_split_feature,
    uint32_t* larger_leaf_best_split_threshold,
    uint8_t* larger_leaf_best_split_default_left,
    int* best_leaf_index,
    int* num_cat_threshold) {
  LaunchFindBestFromAllSplitsKernel(cur_num_leaves, smaller_leaf_index, larger_leaf_index,
                                    smaller_leaf_best_split_feature,
                                    smaller_leaf_best_split_threshold,
                                    smaller_leaf_best_split_default_left,
                                    larger_leaf_best_split_feature,
                                    larger_leaf_best_split_threshold,
                                    larger_leaf_best_split_default_left,
                                    best_leaf_index, num_cat_threshold);
  SynchronizeCUDADevice(__FILE__, __LINE__);
  const int chosen_leaf = *best_leaf_index;
  return &cuda_leaf_best_split_info_[chosen_leaf];
}

// Forwards directly to LaunchAllocateCatVectorsKernel for `len` split infos and their
// categorical-threshold storage; no host-side work is done here.
void CUDABestSplitFinder::AllocateCatVectors(CUDASplitInfo* cuda_split_infos,
                                             uint32_t* cat_threshold_vec,
                                             int* cat_threshold_real_vec,
                                             size_t len) {
  LaunchAllocateCatVectorsKernel(cuda_split_infos,
                                 cat_threshold_vec,
                                 cat_threshold_real_vec,
                                 len);
}

374
375
376
377
378
379
380
381
382
383
// Uploads the per-node feature masks for the smaller and larger leaves. This is a no-op unless
// feature selection by node is enabled; Init() sizes the device masks under the same flag.
void CUDABestSplitFinder::SetUsedFeatureByNode(const std::vector<int8_t>& is_feature_used_by_smaller_node,
                                               const std::vector<int8_t>& is_feature_used_by_larger_node) {
  if (!select_features_by_node_) {
    return;
  }
  int8_t* const device_smaller_mask = is_feature_used_by_smaller_node_.RawData();
  CopyFromHostToCUDADevice<int8_t>(device_smaller_mask,
                                   is_feature_used_by_smaller_node.data(),
                                   is_feature_used_by_smaller_node.size(),
                                   __FILE__, __LINE__);
  int8_t* const device_larger_mask = is_feature_used_by_larger_node_.RawData();
  CopyFromHostToCUDADevice<int8_t>(device_larger_mask,
                                   is_feature_used_by_larger_node.data(),
                                   is_feature_used_by_larger_node.size(),
                                   __FILE__, __LINE__);
}

384
385
386
}  // namespace LightGBM

#endif  // USE_CUDA_EXP