/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/metric.h>

#include <string>

#include "binary_metric.hpp"
#include "map_metric.hpp"
#include "multiclass_metric.hpp"
#include "rank_metric.hpp"
#include "regression_metric.hpp"
#include "xentropy_metric.hpp"

#include "cuda/cuda_binary_metric.hpp"
#include "cuda/cuda_regression_metric.hpp"

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
21
Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
22
23
  #ifdef USE_CUDA
  if (config.device_type == std::string("cuda") && config.boosting == std::string("gbdt")) {
24
    if (type == std::string("l2")) {
25
      return new CUDAL2Metric(config);
26
    } else if (type == std::string("rmse")) {
27
      return new CUDARMSEMetric(config);
28
    } else if (type == std::string("l1")) {
29
      return new CUDAL1Metric(config);
30
    } else if (type == std::string("quantile")) {
31
      return new CUDAQuantileMetric(config);
32
    } else if (type == std::string("huber")) {
33
      return new CUDAHuberLossMetric(config);
34
    } else if (type == std::string("fair")) {
35
      return new CUDAFairLossMetric(config);
36
    } else if (type == std::string("poisson")) {
37
      return new CUDAPoissonMetric(config);
38
    } else if (type == std::string("binary_logloss")) {
39
      return new CUDABinaryLoglossMetric(config);
40
    } else if (type == std::string("binary_error")) {
41
      Log::Warning("Metric binary_error is not implemented in cuda version. Fall back to evaluation on CPU.");
42
43
      return new BinaryErrorMetric(config);
    } else if (type == std::string("auc")) {
44
      Log::Warning("Metric auc is not implemented in cuda version. Fall back to evaluation on CPU.");
45
46
      return new AUCMetric(config);
    } else if (type == std::string("average_precision")) {
47
      Log::Warning("Metric average_precision is not implemented in cuda version. Fall back to evaluation on CPU.");
48
49
      return new AveragePrecisionMetric(config);
    } else if (type == std::string("auc_mu")) {
50
      Log::Warning("Metric auc_mu is not implemented in cuda version. Fall back to evaluation on CPU.");
51
52
      return new AucMuMetric(config);
    } else if (type == std::string("ndcg")) {
53
      Log::Warning("Metric ndcg is not implemented in cuda version. Fall back to evaluation on CPU.");
54
55
      return new NDCGMetric(config);
    } else if (type == std::string("map")) {
56
      Log::Warning("Metric map is not implemented in cuda version. Fall back to evaluation on CPU.");
57
58
      return new MapMetric(config);
    } else if (type == std::string("multi_logloss")) {
59
      Log::Warning("Metric multi_logloss is not implemented in cuda version. Fall back to evaluation on CPU.");
60
61
      return new MultiSoftmaxLoglossMetric(config);
    } else if (type == std::string("multi_error")) {
62
      Log::Warning("Metric multi_error is not implemented in cuda version. Fall back to evaluation on CPU.");
63
64
      return new MultiErrorMetric(config);
    } else if (type == std::string("cross_entropy")) {
65
      Log::Warning("Metric cross_entropy is not implemented in cuda version. Fall back to evaluation on CPU.");
66
67
      return new CrossEntropyMetric(config);
    } else if (type == std::string("cross_entropy_lambda")) {
68
      Log::Warning("Metric cross_entropy_lambda is not implemented in cuda version. Fall back to evaluation on CPU.");
69
70
      return new CrossEntropyLambdaMetric(config);
    } else if (type == std::string("kullback_leibler")) {
71
      Log::Warning("Metric kullback_leibler is not implemented in cuda version. Fall back to evaluation on CPU.");
72
73
      return new KullbackLeiblerDivergence(config);
    } else if (type == std::string("mape")) {
74
      return new CUDAMAPEMetric(config);
75
    } else if (type == std::string("gamma")) {
76
      return new CUDAGammaMetric(config);
77
    } else if (type == std::string("gamma_deviance")) {
78
      return new CUDAGammaDevianceMetric(config);
79
    } else if (type == std::string("tweedie")) {
80
      return new CUDATweedieMetric(config);
81
82
83
    } else if (type == std::string("r2")) {
      Log::Warning("Metric r2 is not implemented in cuda version. Fall back to evaluation on CPU.");
      return new R2Metric(config);
84
85
    }
  } else {
86
  #endif  // USE_CUDA
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    if (type == std::string("l2")) {
      return new L2Metric(config);
    } else if (type == std::string("rmse")) {
      return new RMSEMetric(config);
    } else if (type == std::string("l1")) {
      return new L1Metric(config);
    } else if (type == std::string("quantile")) {
      return new QuantileMetric(config);
    } else if (type == std::string("huber")) {
      return new HuberLossMetric(config);
    } else if (type == std::string("fair")) {
      return new FairLossMetric(config);
    } else if (type == std::string("poisson")) {
      return new PoissonMetric(config);
    } else if (type == std::string("binary_logloss")) {
      return new BinaryLoglossMetric(config);
    } else if (type == std::string("binary_error")) {
      return new BinaryErrorMetric(config);
    } else if (type == std::string("auc")) {
      return new AUCMetric(config);
    } else if (type == std::string("average_precision")) {
      return new AveragePrecisionMetric(config);
    } else if (type == std::string("auc_mu")) {
      return new AucMuMetric(config);
    } else if (type == std::string("ndcg")) {
      return new NDCGMetric(config);
    } else if (type == std::string("map")) {
      return new MapMetric(config);
    } else if (type == std::string("multi_logloss")) {
      return new MultiSoftmaxLoglossMetric(config);
    } else if (type == std::string("multi_error")) {
      return new MultiErrorMetric(config);
    } else if (type == std::string("cross_entropy")) {
      return new CrossEntropyMetric(config);
    } else if (type == std::string("cross_entropy_lambda")) {
      return new CrossEntropyLambdaMetric(config);
    } else if (type == std::string("kullback_leibler")) {
      return new KullbackLeiblerDivergence(config);
    } else if (type == std::string("mape")) {
      return new MAPEMetric(config);
    } else if (type == std::string("gamma")) {
      return new GammaMetric(config);
    } else if (type == std::string("gamma_deviance")) {
      return new GammaDevianceMetric(config);
    } else if (type == std::string("tweedie")) {
      return new TweedieMetric(config);
133
134
    } else if (type == std::string("r2")) {
      return new R2Metric(config);
135
    }
136
  #ifdef USE_CUDA
Guolin Ke's avatar
Guolin Ke committed
137
  }
138
  #endif  // USE_CUDA
Guolin Ke's avatar
Guolin Ke committed
139
140
141
142
  return nullptr;
}

}  // namespace LightGBM