cuda_split_info.hpp 2.89 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
 */

8
#ifdef USE_CUDA
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

#ifndef LIGHTGBM_CUDA_CUDA_SPLIT_INFO_HPP_
#define LIGHTGBM_CUDA_CUDA_SPLIT_INFO_HPP_

#include <LightGBM/meta.h>

namespace LightGBM {

/*!
 * \brief Description of a split of a single leaf, usable from both host and
 *        device code.
 *
 * Holds the split feature/threshold and gain, plus per-child statistics for
 * the left and right children. For categorical splits, \c cat_threshold and
 * \c cat_threshold_real point to \c num_cat_threshold entries each.
 *
 * NOTE(review): only copy-assignment is user-defined. Copy construction is
 * still the implicit member-wise (shallow) copy, so two objects can end up
 * sharing the same categorical buffers and both try to release them in their
 * destructors. Confirm that callers never copy-construct objects that own
 * buffers.
 */
class CUDASplitInfo {
 public:
  // Whether this record holds a usable split.
  bool is_valid;
  // Index of the leaf this split applies to.
  int leaf_index;
  // Gain of the split.
  double gain;
  // Index of the feature the split is made on.
  int inner_feature_index;
  // Split threshold (presumably a bin index, given the uint32_t type — TODO confirm).
  uint32_t threshold;
  // Presumably: whether default/missing values go to the left child — confirm.
  bool default_left;

  // ---- Left child statistics ----
  double left_sum_gradients;
  double left_sum_hessians;
  // NOTE(review): looks like gradient and hessian sums packed into a single
  // integer (quantized training?) — confirm against the code that writes it.
  int64_t left_sum_of_gradients_hessians;
  data_size_t left_count;
  double left_gain;
  double left_value;

  // ---- Right child statistics ----
  double right_sum_gradients;
  double right_sum_hessians;
  // NOTE(review): see left_sum_of_gradients_hessians.
  int64_t right_sum_of_gradients_hessians;
  data_size_t right_count;
  double right_gain;
  double right_value;

  // ---- Categorical split data ----
  // Number of valid entries in the two arrays below. Who owns the arrays
  // depends on who assigned the pointers (see the destructor and operator=).
  int num_cat_threshold = 0;
  uint32_t* cat_threshold = nullptr;
  int* cat_threshold_real = nullptr;

  /*!
   * \brief Initializes only the categorical fields (count 0, null buffers).
   *
   * All other members are left uninitialized and must be written before use.
   */
  __host__ __device__ CUDASplitInfo() {
    num_cat_threshold = 0;
    cat_threshold = nullptr;
    cat_threshold_real = nullptr;
  }

  /*!
   * \brief Releases the categorical buffers with cudaFree, but only when
   *        \c num_cat_threshold > 0.
   *
   * When \c num_cat_threshold is 0, non-null pointers are left untouched.
   *
   * NOTE(review): operator= allocates these buffers with new[], yet they are
   * released here with cudaFree. Mixing allocators is invalid for buffers
   * created by operator=. The correct deallocator depends on where else the
   * pointers get assigned — confirm and make the two consistent.
   */
  __host__ __device__ ~CUDASplitInfo() {
    if (num_cat_threshold > 0) {
      if (cat_threshold != nullptr) {
        CUDASUCCESS_OR_FATAL(cudaFree(cat_threshold));
      }
      if (cat_threshold_real != nullptr) {
        CUDASUCCESS_OR_FATAL(cudaFree(cat_threshold_real));
      }
    }
  }

  /*!
   * \brief Copies every field of \p other into this object.
   *
   * Scalar fields, including the packed gradient/hessian sums, are copied
   * directly. When \p other has categorical entries, the threshold arrays
   * are deep-copied:
   *  - a destination pointer that is still nullptr is allocated with new[];
   *  - an existing destination buffer is reused as-is.
   * If a source array is nullptr, the destination array keeps its current
   * (possibly uninitialized) contents.
   *
   * NOTE(review): a reused destination buffer is assumed to hold at least
   * other.num_cat_threshold elements; its capacity is not checked here.
   *
   * \param other Split to copy from.
   * \return *this
   */
  __host__ __device__ CUDASplitInfo& operator=(const CUDASplitInfo& other) {
    // Self-assignment: nothing to copy. Returning early also avoids
    // allocating fresh, uninitialized buffers when the pointers are null.
    if (this == &other) {
      return *this;
    }

    is_valid = other.is_valid;
    leaf_index = other.leaf_index;
    gain = other.gain;
    inner_feature_index = other.inner_feature_index;
    threshold = other.threshold;
    default_left = other.default_left;

    left_sum_gradients = other.left_sum_gradients;
    left_sum_hessians = other.left_sum_hessians;
    // Previously omitted: without this the packed sums were left stale.
    left_sum_of_gradients_hessians = other.left_sum_of_gradients_hessians;
    left_count = other.left_count;
    left_gain = other.left_gain;
    left_value = other.left_value;

    right_sum_gradients = other.right_sum_gradients;
    right_sum_hessians = other.right_sum_hessians;
    // Previously omitted: without this the packed sums were left stale.
    right_sum_of_gradients_hessians = other.right_sum_of_gradients_hessians;
    right_count = other.right_count;
    right_gain = other.right_gain;
    right_value = other.right_value;

    num_cat_threshold = other.num_cat_threshold;
    if (num_cat_threshold > 0) {
      // Allocate destination storage only if none is attached yet.
      if (cat_threshold == nullptr) {
        cat_threshold = new uint32_t[num_cat_threshold];
      }
      if (cat_threshold_real == nullptr) {
        cat_threshold_real = new int[num_cat_threshold];
      }
      if (other.cat_threshold != nullptr) {
        for (int i = 0; i < num_cat_threshold; ++i) {
          cat_threshold[i] = other.cat_threshold[i];
        }
      }
      if (other.cat_threshold_real != nullptr) {
        for (int i = 0; i < num_cat_threshold; ++i) {
          cat_threshold_real[i] = other.cat_threshold_real[i];
        }
      }
    }
    return *this;
  }
};

}  // namespace LightGBM

#endif  // LIGHTGBM_CUDA_CUDA_SPLIT_INFO_HPP_

108
#endif  // USE_CUDA