/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/objective_function.h>
6

7
8
#include <string>

Guolin Ke's avatar
Guolin Ke committed
9
#include "binary_objective.hpp"
10
#include "multiclass_objective.hpp"
11
12
#include "rank_objective.hpp"
#include "regression_objective.hpp"
13
#include "xentropy_objective.hpp"
Guolin Ke's avatar
Guolin Ke committed
14

15
#include "cuda/cuda_binary_objective.hpp"
16
#include "cuda/cuda_multiclass_objective.hpp"
17
#include "cuda/cuda_rank_objective.hpp"
18
#include "cuda/cuda_regression_objective.hpp"
19

Guolin Ke's avatar
Guolin Ke committed
20
21
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
23
24
  #ifdef USE_CUDA
  if (config.device_type == std::string("cuda") &&
25
26
      config.data_sample_strategy != std::string("goss") &&
      config.boosting != std::string("rf")) {
27
    if (type == std::string("regression")) {
28
      return new CUDARegressionL2loss(config);
29
    } else if (type == std::string("regression_l1")) {
30
      return new CUDARegressionL1loss(config);
31
    } else if (type == std::string("quantile")) {
32
      return new CUDARegressionQuantileloss(config);
33
    } else if (type == std::string("huber")) {
34
      return new CUDARegressionHuberLoss(config);
35
    } else if (type == std::string("fair")) {
36
      return new CUDARegressionFairLoss(config);
37
    } else if (type == std::string("poisson")) {
38
      return new CUDARegressionPoissonLoss(config);
39
    } else if (type == std::string("binary")) {
40
      return new CUDABinaryLogloss(config);
41
    } else if (type == std::string("lambdarank")) {
42
      return new CUDALambdarankNDCG(config);
43
    } else if (type == std::string("rank_xendcg")) {
44
      return new CUDARankXENDCG(config);
45
    } else if (type == std::string("multiclass")) {
46
      return new CUDAMulticlassSoftmax(config);
47
    } else if (type == std::string("multiclassova")) {
48
      return new CUDAMulticlassOVA(config);
49
    } else if (type == std::string("cross_entropy")) {
50
      Log::Warning("Objective cross_entropy is not implemented in cuda version. Fall back to boosting on CPU.");
51
52
      return new CrossEntropy(config);
    } else if (type == std::string("cross_entropy_lambda")) {
53
      Log::Warning("Objective cross_entropy_lambda is not implemented in cuda version. Fall back to boosting on CPU.");
54
55
      return new CrossEntropyLambda(config);
    } else if (type == std::string("mape")) {
56
      Log::Warning("Objective mape is not implemented in cuda version. Fall back to boosting on CPU.");
57
58
      return new RegressionMAPELOSS(config);
    } else if (type == std::string("gamma")) {
59
      Log::Warning("Objective gamma is not implemented in cuda version. Fall back to boosting on CPU.");
60
61
      return new RegressionGammaLoss(config);
    } else if (type == std::string("tweedie")) {
62
      Log::Warning("Objective tweedie is not implemented in cuda version. Fall back to boosting on CPU.");
63
64
      return new RegressionTweedieLoss(config);
    } else if (type == std::string("custom")) {
65
      Log::Warning("Using customized objective with cuda. This requires copying gradients from CPU to GPU, which can be slow.");
66
67
68
      return nullptr;
    }
  } else {
69
  #endif  // USE_CUDA
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
    if (type == std::string("regression")) {
      return new RegressionL2loss(config);
    } else if (type == std::string("regression_l1")) {
      return new RegressionL1loss(config);
    } else if (type == std::string("quantile")) {
      return new RegressionQuantileloss(config);
    } else if (type == std::string("huber")) {
      return new RegressionHuberLoss(config);
    } else if (type == std::string("fair")) {
      return new RegressionFairLoss(config);
    } else if (type == std::string("poisson")) {
      return new RegressionPoissonLoss(config);
    } else if (type == std::string("binary")) {
      return new BinaryLogloss(config);
    } else if (type == std::string("lambdarank")) {
      return new LambdarankNDCG(config);
    } else if (type == std::string("rank_xendcg")) {
      return new RankXENDCG(config);
    } else if (type == std::string("multiclass")) {
      return new MulticlassSoftmax(config);
    } else if (type == std::string("multiclassova")) {
      return new MulticlassOVA(config);
    } else if (type == std::string("cross_entropy")) {
      return new CrossEntropy(config);
    } else if (type == std::string("cross_entropy_lambda")) {
      return new CrossEntropyLambda(config);
    } else if (type == std::string("mape")) {
      return new RegressionMAPELOSS(config);
    } else if (type == std::string("gamma")) {
      return new RegressionGammaLoss(config);
    } else if (type == std::string("tweedie")) {
      return new RegressionTweedieLoss(config);
    } else if (type == std::string("custom")) {
      return nullptr;
    }
105
  #ifdef USE_CUDA
Guolin Ke's avatar
Guolin Ke committed
106
  }
107
  #endif  // USE_CUDA
108
  Log::Fatal("Unknown objective type name: %s", type.c_str());
Guolin Ke's avatar
Guolin Ke committed
109
  return nullptr;
Guolin Ke's avatar
Guolin Ke committed
110
}
111
112

/*!
 * \brief Create an objective function from a space-separated description string.
 *
 * The first token selects the objective; the full token list is forwarded to that
 * objective's constructor.
 *
 * \param str Space-separated description whose first token is the objective name.
 * \return An objective allocated with `new` (the caller takes ownership), or nullptr for
 *         "custom". A string without any token, or an unrecognized name, is reported
 *         through Log::Fatal.
 */
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
  auto strs = Common::Split(str.c_str(), ' ');
  // Indexing strs[0] is only valid if splitting produced at least one token; presumably an
  // empty or all-space string yields none.
  if (strs.empty()) {
    Log::Fatal("Objective string \"%s\" does not contain an objective type name", str.c_str());
    return nullptr;
  }
  const auto& type = strs[0];
  if (type == "regression") {
    return new RegressionL2loss(strs);
  } else if (type == "regression_l1") {
    return new RegressionL1loss(strs);
  } else if (type == "quantile") {
    return new RegressionQuantileloss(strs);
  } else if (type == "huber") {
    return new RegressionHuberLoss(strs);
  } else if (type == "fair") {
    return new RegressionFairLoss(strs);
  } else if (type == "poisson") {
    return new RegressionPoissonLoss(strs);
  } else if (type == "binary") {
    return new BinaryLogloss(strs);
  } else if (type == "lambdarank") {
    return new LambdarankNDCG(strs);
  } else if (type == "rank_xendcg") {
    return new RankXENDCG(strs);
  } else if (type == "multiclass") {
    return new MulticlassSoftmax(strs);
  } else if (type == "multiclassova") {
    return new MulticlassOVA(strs);
  } else if (type == "cross_entropy") {
    return new CrossEntropy(strs);
  } else if (type == "cross_entropy_lambda") {
    return new CrossEntropyLambda(strs);
  } else if (type == "mape") {
    return new RegressionMAPELOSS(strs);
  } else if (type == "gamma") {
    return new RegressionGammaLoss(strs);
  } else if (type == "tweedie") {
    return new RegressionTweedieLoss(strs);
  } else if (type == "custom") {
    // Gradients are supplied externally for a custom objective, so there is nothing to build.
    return nullptr;
  }
  Log::Fatal("Unknown objective type name: %s", type.c_str());
  return nullptr;
}

}  // namespace LightGBM