/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#ifdef USE_CUDA_EXP

#ifndef LIGHTGBM_CUDA_CUDA_TREE_HPP_
#define LIGHTGBM_CUDA_CUDA_TREE_HPP_

#include <LightGBM/cuda/cuda_column_data.hpp>
#include <LightGBM/cuda/cuda_split_info.hpp>
#include <LightGBM/tree.h>
#include <LightGBM/bin.h>

namespace LightGBM {

// Device-side helpers operating on a node's packed `decision_type` byte, plus a
// zero test. Only declarations appear in this header.
// NOTE(review): definitions presumably live in the matching .cu source — confirm.
// A __device__ function defined in one translation unit and called from another
// needs relocatable device code (nvcc -rdc=true / separable compilation).

/*! \brief NOTE(review): presumably sets (input == true) or clears the bits selected by `mask` in `*decision_type` — confirm */
__device__ void SetDecisionTypeCUDA(int8_t* decision_type, bool input, int8_t mask);

/*! \brief NOTE(review): presumably stores the missing-value type `input` into the missing-type bits of `*decision_type` — confirm */
__device__ void SetMissingTypeCUDA(int8_t* decision_type, int8_t input);

/*! \brief NOTE(review): presumably returns whether the bits selected by `mask` are set in `decision_type` — confirm */
__device__ bool GetDecisionTypeCUDA(int8_t decision_type, int8_t mask);

/*! \brief NOTE(review): presumably extracts the missing-value type encoded in `decision_type` — confirm */
__device__ int8_t GetMissingTypeCUDA(int8_t decision_type);

/*! \brief NOTE(review): presumably the device counterpart of the host-side zero-threshold test — confirm */
__device__ bool IsZeroCUDA(double fval);

/*!
* \brief Tree whose node/leaf arrays are mirrored in CUDA-side buffers
*        (the `cuda_*` members below) in addition to the host-side Tree state.
*/
class CUDATree : public Tree {
 public:
  /*!
  * \brief Constructor
  * \param max_leaves The number of max leaves
  * \param track_branch_features Whether to keep track of ancestors of leaf nodes
  * \param is_linear Whether the tree has linear models at each leaf
  * \param gpu_device_id ID of the GPU this tree is associated with
  *        (NOTE(review): presumably the device its CUDA buffers are allocated on — confirm)
  * \param has_categorical_feature Whether the data has categorical features
  *        (NOTE(review): presumably decides whether categorical bitset storage is needed — confirm)
  */
  explicit CUDATree(int max_leaves, bool track_branch_features, bool is_linear,
    const int gpu_device_id, const bool has_categorical_feature);

  /*!
  * \brief Build a CUDA tree from an existing host-side tree
  * \param host_tree Source tree
  *        (NOTE(review): presumably its arrays are copied into the CUDA buffers — confirm)
  */
  explicit CUDATree(const Tree* host_tree);

  /*!
  * \brief Destructor
  * NOTE(review): presumably releases the CUDA buffers and stream below — confirm.
  */
  ~CUDATree() noexcept;

  /*!
  * \brief Split a leaf on a numerical feature
  * \param leaf_index Index of the leaf to split
  * \param real_feature_index Index of the split feature
  *        (the `real_` prefix presumably means original-dataset indexing, as opposed to `inner` — confirm)
  * \param real_threshold Split threshold
  * \param missing_type How missing values are routed by this split
  * \param cuda_split_info Description of the chosen split
  *        (NOTE(review): presumably a device pointer — confirm)
  * \return NOTE(review): presumably the index of the newly created leaf, as Tree::Split returns — confirm
  */
  int Split(const int leaf_index,
            const int real_feature_index,
            const double real_threshold,
            const MissingType missing_type,
            const CUDASplitInfo* cuda_split_info);

  /*!
  * \brief Split a leaf on a categorical feature
  * \param leaf_index Index of the leaf to split
  * \param real_feature_index Index of the split feature
  * \param missing_type How missing values are routed by this split
  * \param cuda_split_info Description of the chosen split
  * \param cuda_bitset Category bitset for the split
  * \param cuda_bitset_len Length of `cuda_bitset` (NOTE(review): presumably in uint32_t words — confirm)
  * \param cuda_bitset_inner Category bitset in inner (binned) indexing
  * \param cuda_bitset_inner_len Length of `cuda_bitset_inner`
  * \return NOTE(review): presumably the index of the newly created leaf — confirm
  */
  int SplitCategorical(
    const int leaf_index,
    const int real_feature_index,
    const MissingType missing_type,
    const CUDASplitInfo* cuda_split_info,
    uint32_t* cuda_bitset,
    size_t cuda_bitset_len,
    uint32_t* cuda_bitset_inner,
    size_t cuda_bitset_inner_len);

  /*!
  * \brief Adding prediction value of this tree model to scores
  * \param data The dataset
  * \param num_data Number of total data
  * \param score Will add prediction to score
  */
  void AddPredictionToScore(const Dataset* data,
                            data_size_t num_data,
                            double* score) const override;

  /*!
  * \brief Adding prediction value of this tree model to scores
  * \param data The dataset
  * \param used_data_indices Indices of used data
  * \param num_data Number of total data
  * \param score Will add prediction to score
  */
  void AddPredictionToScore(const Dataset* data,
                            const data_size_t* used_data_indices,
                            data_size_t num_data, double* score) const override;

  // NOTE(review): AsConstantTree, Shrinkage and AddBias are declared `inline` but
  // are not defined in this header. An inline function must be defined in every
  // translation unit that odr-uses it — confirm the out-of-line definitions are
  // reachable wherever these are used.
  inline void AsConstantTree(double val) override;

  // Accessors for the raw CUDA-side tree arrays. The field names mirror the
  // host-side Tree arrays. NOTE(review): presumably device pointers — confirm.
  const int* cuda_leaf_parent() const { return cuda_leaf_parent_; }

  const int* cuda_left_child() const { return cuda_left_child_; }

  const int* cuda_right_child() const { return cuda_right_child_; }

  const int* cuda_split_feature_inner() const { return cuda_split_feature_inner_; }

  const int* cuda_split_feature() const { return cuda_split_feature_; }

  const uint32_t* cuda_threshold_in_bin() const { return cuda_threshold_in_bin_; }

  const double* cuda_threshold() const { return cuda_threshold_; }

  const int8_t* cuda_decision_type() const { return cuda_decision_type_; }

  const double* cuda_leaf_value() const { return cuda_leaf_value_; }

  /*! \brief Mutable access to the CUDA-side leaf values */
  double* cuda_leaf_value_ref() { return cuda_leaf_value_; }

  inline void Shrinkage(double rate) override;

  inline void AddBias(double val) override;

  /*! \brief NOTE(review): presumably copies the CUDA-side tree arrays back into the host Tree — confirm */
  void ToHost();

  /*! \brief NOTE(review): presumably uploads host leaf outputs into cuda_leaf_value_ — confirm */
  void SyncLeafOutputFromHostToCUDA();

  /*! \brief NOTE(review): presumably downloads cuda_leaf_value_ into the host leaf outputs — confirm */
  void SyncLeafOutputFromCUDAToHost();

 private:
  /*! \brief NOTE(review): presumably allocates the cuda_* buffers below — confirm */
  void InitCUDAMemory();

  void InitCUDA();

  // Kernel-launch wrappers backing the public operations of the same name.
  void LaunchSplitKernel(const int leaf_index,
                         const int real_feature_index,
                         const double real_threshold,
                         const MissingType missing_type,
                         const CUDASplitInfo* cuda_split_info);

  void LaunchSplitCategoricalKernel(
    const int leaf_index,
    const int real_feature_index,
    const MissingType missing_type,
    const CUDASplitInfo* cuda_split_info,
    size_t cuda_bitset_len,
    size_t cuda_bitset_inner_len);

  void LaunchAddPredictionToScoreKernel(const Dataset* data,
                                        const data_size_t* used_data_indices,
                                        data_size_t num_data, double* score) const;

  void LaunchShrinkageKernel(const double rate);

  void LaunchAddBiasKernel(const double val);

  void RecordBranchFeatures(const int left_leaf_index,
                            const int right_leaf_index,
                            const int real_feature_index);

  // CUDA-side copies of the tree arrays. They default to nullptr so that they
  // never hold indeterminate values before the constructor assigns them.
  int* cuda_left_child_ = nullptr;
  int* cuda_right_child_ = nullptr;
  int* cuda_split_feature_inner_ = nullptr;
  int* cuda_split_feature_ = nullptr;
  int* cuda_leaf_depth_ = nullptr;
  int* cuda_leaf_parent_ = nullptr;
  uint32_t* cuda_threshold_in_bin_ = nullptr;
  double* cuda_threshold_ = nullptr;
  double* cuda_internal_weight_ = nullptr;
  double* cuda_internal_value_ = nullptr;
  int8_t* cuda_decision_type_ = nullptr;
  double* cuda_leaf_value_ = nullptr;
  data_size_t* cuda_leaf_count_ = nullptr;
  double* cuda_leaf_weight_ = nullptr;
  data_size_t* cuda_internal_count_ = nullptr;
  float* cuda_split_gain_ = nullptr;
  CUDAVector<uint32_t> cuda_bitset_;
  CUDAVector<uint32_t> cuda_bitset_inner_;
  CUDAVector<int> cuda_cat_boundaries_;
  CUDAVector<int> cuda_cat_boundaries_inner_;

  // nullptr is the legacy default stream until the constructor assigns one.
  cudaStream_t cuda_stream_ = nullptr;

  const int num_threads_per_block_add_prediction_to_score_;
};

}  // namespace LightGBM

#endif  // LIGHTGBM_CUDA_CUDA_TREE_HPP_

#endif  // USE_CUDA_EXP