// cuda_leaf_splits.cpp
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifdef USE_CUDA

#include "cuda_leaf_splits.hpp"

namespace LightGBM {

/*!
 * \brief Construct the leaf-splits holder for a given data count.
 * \param num_data Data count stored in num_data_; Init() derives the number of
 *        reduction blocks from it.
 */
CUDALeafSplits::CUDALeafSplits(const data_size_t num_data) : num_data_(num_data) {}
15

16
/*! \brief Destructor; performs no explicit work of its own. */
CUDALeafSplits::~CUDALeafSplits() = default;
17

18
void CUDALeafSplits::Init(const bool use_quantized_grad) {
19
20
21
22
  num_blocks_init_from_gradients_ = (num_data_ + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;

  // allocate more memory for sum reduction in CUDA
  // only the first element records the final sum
23
24
25
26
27
  cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  if (use_quantized_grad) {
    cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  }
28

29
  cuda_struct_.Resize(1);
30
31
32
33
34
35
36
37
38
39
40
}

/*!
 * \brief Argument-free overload: runs the "empty" initialization kernel launcher
 *        and waits for the device to finish.
 */
void CUDALeafSplits::InitValues() {
  // launch first, then block until the device has completed the work
  LaunchInitValuesEmptyKernel();
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDALeafSplits::InitValues(
  const double lambda_l1, const double lambda_l2,
  const score_t* cuda_gradients, const score_t* cuda_hessians,
  const data_size_t* cuda_bagging_data_indices, const data_size_t* cuda_data_indices_in_leaf,
41
42
  const data_size_t num_used_indices, hist_t* cuda_hist_in_leaf,
  double* root_sum_gradients, double* root_sum_hessians) {
43
44
  cuda_gradients_ = cuda_gradients;
  cuda_hessians_ = cuda_hessians;
45
46
  cuda_sum_of_gradients_buffer_.SetValue(0);
  cuda_sum_of_hessians_buffer_.SetValue(0);
47
  LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf);
48
  CopyFromCUDADeviceToHost<double>(root_sum_gradients, cuda_sum_of_gradients_buffer_.RawData(), 1, __FILE__, __LINE__);
49
50
51
52
53
54
55
56
57
  CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDALeafSplits::InitValues(
  const double lambda_l1, const double lambda_l2,
  const int16_t* cuda_gradients_and_hessians,
  const data_size_t* cuda_bagging_data_indices,
  const data_size_t* cuda_data_indices_in_leaf, const data_size_t num_used_indices,
58
  hist_t* cuda_hist_in_leaf, double* root_sum_gradients, double* root_sum_hessians,
59
60
61
62
  const score_t* grad_scale, const score_t* hess_scale) {
  cuda_gradients_ = reinterpret_cast<const score_t*>(cuda_gradients_and_hessians);
  cuda_hessians_ = nullptr;
  LaunchInitValuesKernal(lambda_l1, lambda_l2, cuda_bagging_data_indices, cuda_data_indices_in_leaf, num_used_indices, cuda_hist_in_leaf, grad_scale, hess_scale);
63
  CopyFromCUDADeviceToHost<double>(root_sum_gradients, cuda_sum_of_gradients_buffer_.RawData(), 1, __FILE__, __LINE__);
64
  CopyFromCUDADeviceToHost<double>(root_sum_hessians, cuda_sum_of_hessians_buffer_.RawData(), 1, __FILE__, __LINE__);
65
66
67
68
69
  SynchronizeCUDADevice(__FILE__, __LINE__);
}

void CUDALeafSplits::Resize(const data_size_t num_data) {
  num_data_ = num_data;
70
71
72
73
  num_blocks_init_from_gradients_ = (num_data + NUM_THRADS_PER_BLOCK_LEAF_SPLITS - 1) / NUM_THRADS_PER_BLOCK_LEAF_SPLITS;
  cuda_sum_of_gradients_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  cuda_sum_of_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
  cuda_sum_of_gradients_hessians_buffer_.Resize(static_cast<size_t>(num_blocks_init_from_gradients_));
74
75
76
77
}

}  // namespace LightGBM

78
#endif  // USE_CUDA