/*!
 * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

#ifndef LIGHTGBM_INCLUDE_LIGHTGBM_CUDA_CUDA_METRIC_HPP_
#define LIGHTGBM_INCLUDE_LIGHTGBM_CUDA_CUDA_METRIC_HPP_
#ifdef USE_CUDA
#include <LightGBM/cuda/cuda_utils.hu>
#include <LightGBM/metric.h>

namespace LightGBM {

template <typename HOST_METRIC>
class CUDAMetricInterface: public HOST_METRIC {
 public:
  explicit CUDAMetricInterface(const Config& config): HOST_METRIC(config) {
    cuda_labels_ = nullptr;
    cuda_weights_ = nullptr;
23
24
    const int gpu_device_id = config.gpu_device_id >= 0 ? config.gpu_device_id : 0;
    SetCUDADevice(gpu_device_id, __FILE__, __LINE__);
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    HOST_METRIC::Init(metadata, num_data);
    cuda_labels_ = metadata.cuda_metadata()->cuda_label();
    cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
  }

  bool IsCUDAMetric() const { return true; }

 protected:
  const label_t* cuda_labels_;
  const label_t* cuda_weights_;
};

}  // namespace LightGBM

#endif  // USE_CUDA
#endif  // LIGHTGBM_INCLUDE_LIGHTGBM_CUDA_CUDA_METRIC_HPP_