/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

6
7
#ifndef LIGHTGBM_INCLUDE_LIGHTGBM_CUDA_CUDA_METADATA_HPP_
#define LIGHTGBM_INCLUDE_LIGHTGBM_CUDA_CUDA_METADATA_HPP_
8

9
#ifdef USE_CUDA
10

11
#include <LightGBM/cuda/cuda_utils.hu>
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <LightGBM/meta.h>

#include <vector>

namespace LightGBM {

/*!
 * \brief Holds the CUDA-side pointers for dataset metadata: labels, weights,
 *        query boundaries, query weights and initial scores.
 *
 * Only declarations are visible in this header; the constructor, destructor,
 * Init and the Set* methods are defined elsewhere.
 *
 * NOTE(review): the class declares a destructor and stores raw pointers, but
 * the implicit copy operations are left enabled. If the destructor releases
 * these buffers, copying an instance would be unsafe -- confirm against the
 * implementation and consider deleting copy construction/assignment.
 */
class CUDAMetadata {
 public:
  /*!
   * \brief Constructor
   * \param gpu_device_id Id of the GPU device to use
   */
  explicit CUDAMetadata(const int gpu_device_id);

  /*! \brief Destructor (defined out of line) */
  ~CUDAMetadata();

  /*!
   * \brief Initializes the metadata from host-side vectors
   * \param label Labels
   * \param weight Weights
   * \param query_boundaries Query boundaries
   * \param query_weights Query weights
   * \param init_score Initial scores
   */
  void Init(const std::vector<label_t>& label,
            const std::vector<label_t>& weight,
            const std::vector<data_size_t>& query_boundaries,
            const std::vector<label_t>& query_weights,
            const std::vector<double>& init_score);

  /*!
   * \brief Sets the labels
   * \param label Pointer to the label values
   * \param len Number of elements in label
   */
  void SetLabel(const label_t* label, data_size_t len);

  /*!
   * \brief Sets the weights
   * \param weights Pointer to the weight values
   * \param len Number of elements in weights
   */
  void SetWeights(const label_t* weights, data_size_t len);

  /*!
   * \brief Sets the query information
   * \param query Pointer to the query data
   * \param query_weights Pointer to the query weights
   * \param num_queries Number of queries
   */
  void SetQuery(const data_size_t* query, const label_t* query_weights, data_size_t num_queries);

  /*!
   * \brief Sets the initial scores
   * \param init_score Pointer to the initial score values
   * \param len Number of elements in init_score
   */
  void SetInitScore(const double* init_score, data_size_t len);

  /*! \brief Read-only access to the stored label pointer */
  const label_t* cuda_label() const { return cuda_label_; }

  /*! \brief Read-only access to the stored weights pointer */
  const label_t* cuda_weights() const { return cuda_weights_; }

  /*! \brief Read-only access to the stored query boundaries pointer */
  const data_size_t* cuda_query_boundaries() const { return cuda_query_boundaries_; }

  /*! \brief Read-only access to the stored query weights pointer */
  const label_t* cuda_query_weights() const { return cuda_query_weights_; }

 private:
  // Every pointer defaults to nullptr so the accessors above never expose an
  // indeterminate value, whatever the out-of-line constructor assigns.
  /*! \brief Pointer to the labels */
  label_t* cuda_label_ = nullptr;
  /*! \brief Pointer to the weights */
  label_t* cuda_weights_ = nullptr;
  /*! \brief Pointer to the query boundaries */
  data_size_t* cuda_query_boundaries_ = nullptr;
  /*! \brief Pointer to the query weights */
  label_t* cuda_query_weights_ = nullptr;
  /*! \brief Pointer to the initial scores (no public accessor is declared in this header) */
  double* cuda_init_score_ = nullptr;
};

}  // namespace LightGBM

56
#endif  // USE_CUDA
57
58

#endif  // LIGHTGBM_INCLUDE_LIGHTGBM_CUDA_CUDA_METADATA_HPP_