cuda_pointwise_metric.hpp 1.29 KB
Newer Older
1
2
3
4
5
6
/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

7
8
#ifndef LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_
#define LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_
9

10
#ifdef USE_CUDA
11
12

#include <LightGBM/cuda/cuda_metric.hpp>
13
#include <LightGBM/cuda/cuda_utils.hu>
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

#include <vector>

// Work-granularity constant for the pointwise metric evaluation kernels.
// Judging by the name, this is the number of data rows handled per thread
// (NOTE(review): confirm against the kernel implementation, which is not
// visible in this header). Left as a macro since other translation units may
// rely on it being a preprocessor symbol.
#define NUM_DATA_PER_EVAL_THREAD (1024)

namespace LightGBM {

/*!
 * \brief Common CUDA base class for pointwise metrics.
 *
 * Derives from CUDAMetricInterface<HOST_METRIC>. It adds the device buffers and
 * a kernel-launch hook that produce an accumulated loss and weight.
 * Init() and LaunchEvalKernel() are only declared here. Their definitions are
 * out of line, presumably in the matching .cu file (NOTE(review): confirm).
 *
 * \tparam HOST_METRIC Metric type forwarded to CUDAMetricInterface.
 * \tparam CUDA_METRIC Not referenced by any declaration in this header;
 *         presumably used by the out-of-line definitions (TODO confirm).
 */
template <typename HOST_METRIC, typename CUDA_METRIC>
class CUDAPointwiseMetricInterface: public CUDAMetricInterface<HOST_METRIC> {
 public:
  /*!
   * \brief Forwards \p config to the base class and caches config.num_class.
   * \param config Configuration object passed to the base constructor.
   */
  explicit CUDAPointwiseMetricInterface(const Config& config): CUDAMetricInterface<HOST_METRIC>(config), num_class_(config.num_class) {}

  virtual ~CUDAPointwiseMetricInterface() {}

  /*!
   * \brief Overrides the base-class Init; the definition is out of line.
   * \param metadata Dataset metadata.
   * \param num_data Number of data rows.
   */
  void Init(const Metadata& metadata, data_size_t num_data) override;

 protected:
  /*!
   * \brief Launches the evaluation kernel(s); the definition is out of line.
   * \param score_convert Converted scores (assumed to be a device pointer -- TODO confirm).
   * \param sum_loss Output location for the accumulated loss (assumed device pointer -- confirm).
   * \param sum_weight Output location for the accumulated weight (assumed device pointer -- confirm).
   */
  void LaunchEvalKernel(const double* score_convert, double* sum_loss, double* sum_weight) const;

  /*!
   * \brief Scalar parameter hook for derived metrics; returns 0.0 by default.
   *
   * Virtual so that a derived metric can supply its own value, presumably read
   * from its config. How the value is consumed is not visible in this header.
   */
  virtual double GetParamFromConfig() const { return 0.0; }

  // Declared mutable so that it can be modified from const member functions.
  mutable CUDAVector<double> score_convert_buffer_;
  // Scratch buffers; presumably hold per-block partial results of the
  // loss/weight reduction (NOTE(review): confirm in the .cu implementation).
  CUDAVector<double> reduce_block_buffer_;
  CUDAVector<double> reduce_block_buffer_inner_;
  // Copy of config.num_class taken at construction time.
  const int num_class_;
};

}  // namespace LightGBM

43
#endif  // USE_CUDA
44

45
#endif  // LIGHTGBM_SRC_METRIC_CUDA_CUDA_POINTWISE_METRIC_HPP_