"examples/vscode:/vscode.git/clone" did not exist on "fdc582ea6ba13faf15ee6707c7c7542790c8821d"
cuda_pointwise_metric.cpp 2.76 KB
Newer Older
1
2
3
4
5
6
/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */

7
#if defined(USE_CUDA) || defined(USE_ROCM)
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

#include "cuda_binary_metric.hpp"
#include "cuda_pointwise_metric.hpp"
#include "cuda_regression_metric.hpp"

namespace LightGBM {

/*!
 * \brief Initialize the metric and size the two reduction scratch buffers.
 *
 * Delegates to CUDAMetricInterface<HOST_METRIC>::Init first, then resizes
 * reduce_block_buffer_ and reduce_block_buffer_inner_ based on num_data_ and
 * NUM_DATA_PER_EVAL_THREAD. Both buffers get twice as many entries when
 * cuda_weights_ is non-null.
 *
 * \param metadata Dataset metadata, forwarded unchanged to the base Init
 * \param num_data Number of data points, forwarded unchanged to the base Init
 */
template <typename HOST_METRIC, typename CUDA_METRIC>
void CUDAPointwiseMetricInterface<HOST_METRIC, CUDA_METRIC>::Init(const Metadata& metadata, data_size_t num_data) {
  // The base Init runs first; num_data_ and cuda_weights_ are read below.
  CUDAMetricInterface<HOST_METRIC>::Init(metadata, num_data);

  // One entry per block without weights, two entries per block with weights.
  // NOTE(review): presumably the extra entry holds a weight partial sum — confirm against the evaluation kernel.
  const int entries_per_block = (this->cuda_weights_ == nullptr) ? 1 : 2;

  // Outer level: ceil(num_data_ / NUM_DATA_PER_EVAL_THREAD) blocks.
  const int outer_block_count = (this->num_data_ + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD;
  reduce_block_buffer_.Resize(outer_block_count * entries_per_block);

  // Inner level: the same ceiling division applied to the outer block count.
  const int inner_block_count = (outer_block_count + NUM_DATA_PER_EVAL_THREAD - 1) / NUM_DATA_PER_EVAL_THREAD;
  reduce_block_buffer_inner_.Resize(inner_block_count * entries_per_block);
}

// Explicit instantiations of Init, one per host-metric / CUDA-metric pairing.
// The template definition of Init lives in this source file, so each pairing
// that needs it is instantiated here explicitly.
// NOTE(review): the bare numeric lines interleaved below are not valid C++ and
// look like line-number residue from a web code viewer — confirm and remove.
template void CUDAPointwiseMetricInterface<RMSEMetric, CUDARMSEMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<L2Metric, CUDAL2Metric>::Init(const Metadata& metadata, data_size_t num_data);
34
template void CUDAPointwiseMetricInterface<QuantileMetric, CUDAQuantileMetric>::Init(const Metadata& metadata, data_size_t num_data);
35
template void CUDAPointwiseMetricInterface<BinaryLoglossMetric, CUDABinaryLoglossMetric>::Init(const Metadata& metadata, data_size_t num_data);
36
37
38
39
40
41
42
43
template void CUDAPointwiseMetricInterface<L1Metric, CUDAL1Metric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<HuberLossMetric, CUDAHuberLossMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<FairLossMetric, CUDAFairLossMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<PoissonMetric, CUDAPoissonMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<MAPEMetric, CUDAMAPEMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<GammaMetric, CUDAGammaMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<GammaDevianceMetric, CUDAGammaDevianceMetric>::Init(const Metadata& metadata, data_size_t num_data);
template void CUDAPointwiseMetricInterface<TweedieMetric, CUDATweedieMetric>::Init(const Metadata& metadata, data_size_t num_data);
44
45
46

}  // namespace LightGBM

47
#endif  // USE_CUDA || USE_ROCM