objective_function.cpp 4.28 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#include <LightGBM/objective_function.h>
2

Guolin Ke's avatar
Guolin Ke committed
3
#include "binary_objective.hpp"
4
#include "multiclass_objective.hpp"
5
6
#include "rank_objective.hpp"
#include "regression_objective.hpp"
7
#include "xentropy_objective.hpp"
Guolin Ke's avatar
Guolin Ke committed
8
9
10

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
11
// Creates the objective function named by `type`, constructing it from `config`.
// Several aliases select the same objective, e.g. "regression", "regression_l2",
// "mean_squared_error", "mse", "l2_root", "root_mean_squared_error" and "rmse"
// all produce a RegressionL2loss.
// Returns an object allocated with `new` (the caller is responsible for deleting
// it), or nullptr when `type` is "none", "null", "custom" or "na".
// Any other name is reported through Log::Fatal.
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
  // std::string::operator== has a `const char*` overload, so comparing against
  // literals directly avoids building a temporary std::string per comparison.
  if (type == "regression" || type == "regression_l2"
      || type == "mean_squared_error" || type == "mse"
      || type == "l2_root" || type == "root_mean_squared_error" || type == "rmse") {
    return new RegressionL2loss(config);
  } else if (type == "regression_l1" || type == "mean_absolute_error" || type == "mae") {
    return new RegressionL1loss(config);
  } else if (type == "quantile") {
    return new RegressionQuantileloss(config);
  } else if (type == "huber") {
    return new RegressionHuberLoss(config);
  } else if (type == "fair") {
    return new RegressionFairLoss(config);
  } else if (type == "poisson") {
    return new RegressionPoissonLoss(config);
  } else if (type == "binary") {
    return new BinaryLogloss(config);
  } else if (type == "lambdarank") {
    return new LambdarankNDCG(config);
  } else if (type == "multiclass" || type == "softmax") {
    return new MulticlassSoftmax(config);
  } else if (type == "multiclassova" || type == "multiclass_ova" || type == "ova" || type == "ovr") {
    return new MulticlassOVA(config);
  } else if (type == "xentropy" || type == "cross_entropy") {
    return new CrossEntropy(config);
  } else if (type == "xentlambda" || type == "cross_entropy_lambda") {
    return new CrossEntropyLambda(config);
  } else if (type == "mean_absolute_percentage_error" || type == "mape") {
    return new RegressionMAPELOSS(config);
  } else if (type == "gamma") {
    return new RegressionGammaLoss(config);
  } else if (type == "tweedie") {
    return new RegressionTweedieLoss(config);
  } else if (type == "none" || type == "null" || type == "custom" || type == "na") {
    // No built-in objective is created for these names.
    return nullptr;
  }
  Log::Fatal("Unknown objective type name: %s", type.c_str());
  // Not reached if Log::Fatal aborts; guarantees every path of this non-void
  // function returns a value instead of falling off the end (undefined behavior).
  return nullptr;
}
49
50

// Creates an objective function from a space-separated description string.
// The first token names the objective; the full token list (name included) is
// passed to that objective's string-based constructor.
// NOTE(review): presumably `str` is the text an objective writes when a model is
// saved, which would explain why only canonical names (plus the "cross_entropy",
// "cross_entropy_lambda" and "mean_absolute_percentage_error" spellings) are
// accepted here, unlike the Config-based overload — confirm.
// Returns an object allocated with `new` (the caller is responsible for deleting
// it), or nullptr when the name is "none", "null", "custom" or "na".
// An empty description or an unknown name is reported through Log::Fatal.
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
  auto strs = Common::Split(str.c_str(), ' ');
  // An empty or all-space `str` can yield no tokens; reading strs[0] would then
  // index past the end of the container.
  if (strs.empty()) {
    Log::Fatal("Empty objective function string");
    return nullptr;
  }
  // Bind by reference: the name is only read, no copy is needed.
  const auto& type = strs[0];
  if (type == "regression") {
    return new RegressionL2loss(strs);
  } else if (type == "regression_l1") {
    return new RegressionL1loss(strs);
  } else if (type == "quantile") {
    return new RegressionQuantileloss(strs);
  } else if (type == "huber") {
    return new RegressionHuberLoss(strs);
  } else if (type == "fair") {
    return new RegressionFairLoss(strs);
  } else if (type == "poisson") {
    return new RegressionPoissonLoss(strs);
  } else if (type == "binary") {
    return new BinaryLogloss(strs);
  } else if (type == "lambdarank") {
    return new LambdarankNDCG(strs);
  } else if (type == "multiclass") {
    return new MulticlassSoftmax(strs);
  } else if (type == "multiclassova") {
    return new MulticlassOVA(strs);
  } else if (type == "xentropy" || type == "cross_entropy") {
    return new CrossEntropy(strs);
  } else if (type == "xentlambda" || type == "cross_entropy_lambda") {
    return new CrossEntropyLambda(strs);
  } else if (type == "mean_absolute_percentage_error" || type == "mape") {
    return new RegressionMAPELOSS(strs);
  } else if (type == "gamma") {
    return new RegressionGammaLoss(strs);
  } else if (type == "tweedie") {
    return new RegressionTweedieLoss(strs);
  } else if (type == "none" || type == "null" || type == "custom" || type == "na") {
    // No built-in objective is created for these names.
    return nullptr;
  }
  Log::Fatal("Unknown objective type name: %s", type.c_str());
  // Not reached if Log::Fatal aborts; guarantees every path of this non-void
  // function returns a value instead of falling off the end (undefined behavior).
  return nullptr;
}

Guolin Ke's avatar
Guolin Ke committed
89
}  // namespace LightGBM