objective_function.cpp 3.78 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/objective_function.h>
6

Guolin Ke's avatar
Guolin Ke committed
7
#include "binary_objective.hpp"
8
#include "multiclass_objective.hpp"
9
#include "rank_objective.hpp"
10
#include "rank_xendcg_objective.hpp"
11
#include "regression_objective.hpp"
12
#include "xentropy_objective.hpp"
Guolin Ke's avatar
Guolin Ke committed
13
14
15

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
16
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
Guolin Ke's avatar
Guolin Ke committed
17
  if (type == std::string("regression")) {
Guolin Ke's avatar
Guolin Ke committed
18
    return new RegressionL2loss(config);
Guolin Ke's avatar
Guolin Ke committed
19
  } else if (type == std::string("regression_l1")) {
20
    return new RegressionL1loss(config);
21
22
  } else if (type == std::string("quantile")) {
    return new RegressionQuantileloss(config);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
23
  } else if (type == std::string("huber")) {
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
24
    return new RegressionHuberLoss(config);
Tsukasa OMOTO's avatar
Tsukasa OMOTO committed
25
26
  } else if (type == std::string("fair")) {
    return new RegressionFairLoss(config);
27
28
  } else if (type == std::string("poisson")) {
    return new RegressionPoissonLoss(config);
Guolin Ke's avatar
Guolin Ke committed
29
  } else if (type == std::string("binary")) {
Guolin Ke's avatar
Guolin Ke committed
30
    return new BinaryLogloss(config);
Guolin Ke's avatar
Guolin Ke committed
31
  } else if (type == std::string("lambdarank")) {
Guolin Ke's avatar
Guolin Ke committed
32
    return new LambdarankNDCG(config);
33
34
  } else if (type == std::string("rank_xendcg")) {
    return new RankXENDCG(config);
Guolin Ke's avatar
Guolin Ke committed
35
  } else if (type == std::string("multiclass")) {
Guolin Ke's avatar
Guolin Ke committed
36
    return new MulticlassSoftmax(config);
Guolin Ke's avatar
Guolin Ke committed
37
  } else if (type == std::string("multiclassova")) {
Guolin Ke's avatar
Guolin Ke committed
38
    return new MulticlassOVA(config);
Guolin Ke's avatar
Guolin Ke committed
39
  } else if (type == std::string("cross_entropy")) {
40
    return new CrossEntropy(config);
Guolin Ke's avatar
Guolin Ke committed
41
  } else if (type == std::string("cross_entropy_lambda")) {
42
    return new CrossEntropyLambda(config);
Guolin Ke's avatar
Guolin Ke committed
43
  } else if (type == std::string("mape")) {
44
    return new RegressionMAPELOSS(config);
Guolin Ke's avatar
Guolin Ke committed
45
46
47
48
  } else if (type == std::string("gamma")) {
    return new RegressionGammaLoss(config);
  } else if (type == std::string("tweedie")) {
    return new RegressionTweedieLoss(config);
Guolin Ke's avatar
Guolin Ke committed
49
  } else if (type == std::string("custom")) {
50
    return nullptr;
Guolin Ke's avatar
Guolin Ke committed
51
  }
52
  Log::Fatal("Unknown objective type name: %s", type.c_str());
Guolin Ke's avatar
Guolin Ke committed
53
}
54
55

ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
Guolin Ke's avatar
Guolin Ke committed
56
  auto strs = Common::Split(str.c_str(), ' ');
57
58
59
60
61
  auto type = strs[0];
  if (type == std::string("regression")) {
    return new RegressionL2loss(strs);
  } else if (type == std::string("regression_l1")) {
    return new RegressionL1loss(strs);
62
63
  } else if (type == std::string("quantile")) {
    return new RegressionQuantileloss(strs);
64
65
66
67
68
69
70
71
72
73
  } else if (type == std::string("huber")) {
    return new RegressionHuberLoss(strs);
  } else if (type == std::string("fair")) {
    return new RegressionFairLoss(strs);
  } else if (type == std::string("poisson")) {
    return new RegressionPoissonLoss(strs);
  } else if (type == std::string("binary")) {
    return new BinaryLogloss(strs);
  } else if (type == std::string("lambdarank")) {
    return new LambdarankNDCG(strs);
74
75
  } else if (type == std::string("rank_xendcg")) {
    return new RankXENDCG(strs);
76
77
78
79
  } else if (type == std::string("multiclass")) {
    return new MulticlassSoftmax(strs);
  } else if (type == std::string("multiclassova")) {
    return new MulticlassOVA(strs);
Guolin Ke's avatar
Guolin Ke committed
80
  } else if (type == std::string("cross_entropy")) {
81
    return new CrossEntropy(strs);
Guolin Ke's avatar
Guolin Ke committed
82
  } else if (type == std::string("cross_entropy_lambda")) {
83
    return new CrossEntropyLambda(strs);
Guolin Ke's avatar
Guolin Ke committed
84
  } else if (type == std::string("mape")) {
Guolin Ke's avatar
Guolin Ke committed
85
    return new RegressionMAPELOSS(strs);
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
  } else if (type == std::string("gamma")) {
    return new RegressionGammaLoss(strs);
  } else if (type == std::string("tweedie")) {
    return new RegressionTweedieLoss(strs);
Guolin Ke's avatar
Guolin Ke committed
90
  } else if (type == std::string("custom")) {
91
    return nullptr;
92
  }
93
  Log::Fatal("Unknown objective type name: %s", type.c_str());
94
95
}

Guolin Ke's avatar
Guolin Ke committed
96
}  // namespace LightGBM