"vscode:/vscode.git/clone" did not exist on "7820746266a9033294365a9129ecdd8a91928a02"
metric.cpp 6.15 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/metric.h>
6

Guolin Ke's avatar
Guolin Ke committed
7
#include "binary_metric.hpp"
Guolin Ke's avatar
Guolin Ke committed
8
#include "map_metric.hpp"
9
#include "multiclass_metric.hpp"
10
11
#include "rank_metric.hpp"
#include "regression_metric.hpp"
12
#include "xentropy_metric.hpp"
Guolin Ke's avatar
Guolin Ke committed
13

14
#include "cuda/cuda_binary_metric.hpp"
15
16
#include "cuda/cuda_regression_metric.hpp"

Guolin Ke's avatar
Guolin Ke committed
17
18
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
19
Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
20
21
  #ifdef USE_CUDA
  if (config.device_type == std::string("cuda") && config.boosting == std::string("gbdt")) {
22
    if (type == std::string("l2")) {
23
      return new CUDAL2Metric(config);
24
    } else if (type == std::string("rmse")) {
25
      return new CUDARMSEMetric(config);
26
    } else if (type == std::string("l1")) {
27
      return new CUDAL1Metric(config);
28
    } else if (type == std::string("quantile")) {
29
      return new CUDAQuantileMetric(config);
30
    } else if (type == std::string("huber")) {
31
      return new CUDAHuberLossMetric(config);
32
    } else if (type == std::string("fair")) {
33
      return new CUDAFairLossMetric(config);
34
    } else if (type == std::string("poisson")) {
35
      return new CUDAPoissonMetric(config);
36
    } else if (type == std::string("binary_logloss")) {
37
      return new CUDABinaryLoglossMetric(config);
38
    } else if (type == std::string("binary_error")) {
39
      Log::Warning("Metric binary_error is not implemented in cuda version. Fall back to evaluation on CPU.");
40
41
      return new BinaryErrorMetric(config);
    } else if (type == std::string("auc")) {
42
      Log::Warning("Metric auc is not implemented in cuda version. Fall back to evaluation on CPU.");
43
44
      return new AUCMetric(config);
    } else if (type == std::string("average_precision")) {
45
      Log::Warning("Metric average_precision is not implemented in cuda version. Fall back to evaluation on CPU.");
46
47
      return new AveragePrecisionMetric(config);
    } else if (type == std::string("auc_mu")) {
48
      Log::Warning("Metric auc_mu is not implemented in cuda version. Fall back to evaluation on CPU.");
49
50
      return new AucMuMetric(config);
    } else if (type == std::string("ndcg")) {
51
      Log::Warning("Metric ndcg is not implemented in cuda version. Fall back to evaluation on CPU.");
52
53
      return new NDCGMetric(config);
    } else if (type == std::string("map")) {
54
      Log::Warning("Metric map is not implemented in cuda version. Fall back to evaluation on CPU.");
55
56
      return new MapMetric(config);
    } else if (type == std::string("multi_logloss")) {
57
      Log::Warning("Metric multi_logloss is not implemented in cuda version. Fall back to evaluation on CPU.");
58
59
      return new MultiSoftmaxLoglossMetric(config);
    } else if (type == std::string("multi_error")) {
60
      Log::Warning("Metric multi_error is not implemented in cuda version. Fall back to evaluation on CPU.");
61
62
      return new MultiErrorMetric(config);
    } else if (type == std::string("cross_entropy")) {
63
      Log::Warning("Metric cross_entropy is not implemented in cuda version. Fall back to evaluation on CPU.");
64
65
      return new CrossEntropyMetric(config);
    } else if (type == std::string("cross_entropy_lambda")) {
66
      Log::Warning("Metric cross_entropy_lambda is not implemented in cuda version. Fall back to evaluation on CPU.");
67
68
      return new CrossEntropyLambdaMetric(config);
    } else if (type == std::string("kullback_leibler")) {
69
      Log::Warning("Metric kullback_leibler is not implemented in cuda version. Fall back to evaluation on CPU.");
70
71
      return new KullbackLeiblerDivergence(config);
    } else if (type == std::string("mape")) {
72
      return new CUDAMAPEMetric(config);
73
    } else if (type == std::string("gamma")) {
74
      return new CUDAGammaMetric(config);
75
    } else if (type == std::string("gamma_deviance")) {
76
      return new CUDAGammaDevianceMetric(config);
77
    } else if (type == std::string("tweedie")) {
78
      return new CUDATweedieMetric(config);
79
80
    }
  } else {
81
  #endif  // USE_CUDA
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    if (type == std::string("l2")) {
      return new L2Metric(config);
    } else if (type == std::string("rmse")) {
      return new RMSEMetric(config);
    } else if (type == std::string("l1")) {
      return new L1Metric(config);
    } else if (type == std::string("quantile")) {
      return new QuantileMetric(config);
    } else if (type == std::string("huber")) {
      return new HuberLossMetric(config);
    } else if (type == std::string("fair")) {
      return new FairLossMetric(config);
    } else if (type == std::string("poisson")) {
      return new PoissonMetric(config);
    } else if (type == std::string("binary_logloss")) {
      return new BinaryLoglossMetric(config);
    } else if (type == std::string("binary_error")) {
      return new BinaryErrorMetric(config);
    } else if (type == std::string("auc")) {
      return new AUCMetric(config);
    } else if (type == std::string("average_precision")) {
      return new AveragePrecisionMetric(config);
    } else if (type == std::string("auc_mu")) {
      return new AucMuMetric(config);
    } else if (type == std::string("ndcg")) {
      return new NDCGMetric(config);
    } else if (type == std::string("map")) {
      return new MapMetric(config);
    } else if (type == std::string("multi_logloss")) {
      return new MultiSoftmaxLoglossMetric(config);
    } else if (type == std::string("multi_error")) {
      return new MultiErrorMetric(config);
    } else if (type == std::string("cross_entropy")) {
      return new CrossEntropyMetric(config);
    } else if (type == std::string("cross_entropy_lambda")) {
      return new CrossEntropyLambdaMetric(config);
    } else if (type == std::string("kullback_leibler")) {
      return new KullbackLeiblerDivergence(config);
    } else if (type == std::string("mape")) {
      return new MAPEMetric(config);
    } else if (type == std::string("gamma")) {
      return new GammaMetric(config);
    } else if (type == std::string("gamma_deviance")) {
      return new GammaDevianceMetric(config);
    } else if (type == std::string("tweedie")) {
      return new TweedieMetric(config);
    }
129
  #ifdef USE_CUDA
Guolin Ke's avatar
Guolin Ke committed
130
  }
131
  #endif  // USE_CUDA
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
  return nullptr;
}

}  // namespace LightGBM