objective_function.cpp 6.51 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/objective_function.h>
6

Guolin Ke's avatar
Guolin Ke committed
7
#include "binary_objective.hpp"
8
#include "multiclass_objective.hpp"
9
10
#include "rank_objective.hpp"
#include "regression_objective.hpp"
11
#include "xentropy_objective.hpp"
Guolin Ke's avatar
Guolin Ke committed
12

13
#include "cuda/cuda_binary_objective.hpp"
14
#include "cuda/cuda_multiclass_objective.hpp"
15
#include "cuda/cuda_rank_objective.hpp"
16
#include "cuda/cuda_regression_objective.hpp"
17

Guolin Ke's avatar
Guolin Ke committed
18
19
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) {
21
22
  #ifdef USE_CUDA
  if (config.device_type == std::string("cuda") &&
23
24
      config.data_sample_strategy != std::string("goss") &&
      config.boosting != std::string("rf")) {
25
    if (type == std::string("regression")) {
26
      return new CUDARegressionL2loss(config);
27
    } else if (type == std::string("regression_l1")) {
28
      return new CUDARegressionL1loss(config);
29
    } else if (type == std::string("quantile")) {
30
      return new CUDARegressionQuantileloss(config);
31
    } else if (type == std::string("huber")) {
32
      return new CUDARegressionHuberLoss(config);
33
    } else if (type == std::string("fair")) {
34
      return new CUDARegressionFairLoss(config);
35
    } else if (type == std::string("poisson")) {
36
      return new CUDARegressionPoissonLoss(config);
37
    } else if (type == std::string("binary")) {
38
      return new CUDABinaryLogloss(config);
39
    } else if (type == std::string("lambdarank")) {
40
      return new CUDALambdarankNDCG(config);
41
    } else if (type == std::string("rank_xendcg")) {
42
      return new CUDARankXENDCG(config);
43
    } else if (type == std::string("multiclass")) {
44
      return new CUDAMulticlassSoftmax(config);
45
    } else if (type == std::string("multiclassova")) {
46
      return new CUDAMulticlassOVA(config);
47
    } else if (type == std::string("cross_entropy")) {
48
      Log::Warning("Objective cross_entropy is not implemented in cuda version. Fall back to boosting on CPU.");
49
50
      return new CrossEntropy(config);
    } else if (type == std::string("cross_entropy_lambda")) {
51
      Log::Warning("Objective cross_entropy_lambda is not implemented in cuda version. Fall back to boosting on CPU.");
52
53
      return new CrossEntropyLambda(config);
    } else if (type == std::string("mape")) {
54
      Log::Warning("Objective mape is not implemented in cuda version. Fall back to boosting on CPU.");
55
56
      return new RegressionMAPELOSS(config);
    } else if (type == std::string("gamma")) {
57
      Log::Warning("Objective gamma is not implemented in cuda version. Fall back to boosting on CPU.");
58
59
      return new RegressionGammaLoss(config);
    } else if (type == std::string("tweedie")) {
60
      Log::Warning("Objective tweedie is not implemented in cuda version. Fall back to boosting on CPU.");
61
62
      return new RegressionTweedieLoss(config);
    } else if (type == std::string("custom")) {
63
      Log::Warning("Using customized objective with cuda. This requires copying gradients from CPU to GPU, which can be slow.");
64
65
66
      return nullptr;
    }
  } else {
67
  #endif  // USE_CUDA
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    if (type == std::string("regression")) {
      return new RegressionL2loss(config);
    } else if (type == std::string("regression_l1")) {
      return new RegressionL1loss(config);
    } else if (type == std::string("quantile")) {
      return new RegressionQuantileloss(config);
    } else if (type == std::string("huber")) {
      return new RegressionHuberLoss(config);
    } else if (type == std::string("fair")) {
      return new RegressionFairLoss(config);
    } else if (type == std::string("poisson")) {
      return new RegressionPoissonLoss(config);
    } else if (type == std::string("binary")) {
      return new BinaryLogloss(config);
    } else if (type == std::string("lambdarank")) {
      return new LambdarankNDCG(config);
    } else if (type == std::string("rank_xendcg")) {
      return new RankXENDCG(config);
    } else if (type == std::string("multiclass")) {
      return new MulticlassSoftmax(config);
    } else if (type == std::string("multiclassova")) {
      return new MulticlassOVA(config);
    } else if (type == std::string("cross_entropy")) {
      return new CrossEntropy(config);
    } else if (type == std::string("cross_entropy_lambda")) {
      return new CrossEntropyLambda(config);
    } else if (type == std::string("mape")) {
      return new RegressionMAPELOSS(config);
    } else if (type == std::string("gamma")) {
      return new RegressionGammaLoss(config);
    } else if (type == std::string("tweedie")) {
      return new RegressionTweedieLoss(config);
    } else if (type == std::string("custom")) {
      return nullptr;
    }
103
  #ifdef USE_CUDA
Guolin Ke's avatar
Guolin Ke committed
104
  }
105
  #endif  // USE_CUDA
106
  Log::Fatal("Unknown objective type name: %s", type.c_str());
Guolin Ke's avatar
Guolin Ke committed
107
  return nullptr;
Guolin Ke's avatar
Guolin Ke committed
108
}
109
110

ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
Guolin Ke's avatar
Guolin Ke committed
111
  auto strs = Common::Split(str.c_str(), ' ');
112
113
114
115
116
  auto type = strs[0];
  if (type == std::string("regression")) {
    return new RegressionL2loss(strs);
  } else if (type == std::string("regression_l1")) {
    return new RegressionL1loss(strs);
117
118
  } else if (type == std::string("quantile")) {
    return new RegressionQuantileloss(strs);
119
120
121
122
123
124
125
126
127
128
  } else if (type == std::string("huber")) {
    return new RegressionHuberLoss(strs);
  } else if (type == std::string("fair")) {
    return new RegressionFairLoss(strs);
  } else if (type == std::string("poisson")) {
    return new RegressionPoissonLoss(strs);
  } else if (type == std::string("binary")) {
    return new BinaryLogloss(strs);
  } else if (type == std::string("lambdarank")) {
    return new LambdarankNDCG(strs);
129
130
  } else if (type == std::string("rank_xendcg")) {
    return new RankXENDCG(strs);
131
132
133
134
  } else if (type == std::string("multiclass")) {
    return new MulticlassSoftmax(strs);
  } else if (type == std::string("multiclassova")) {
    return new MulticlassOVA(strs);
Guolin Ke's avatar
Guolin Ke committed
135
  } else if (type == std::string("cross_entropy")) {
136
    return new CrossEntropy(strs);
Guolin Ke's avatar
Guolin Ke committed
137
  } else if (type == std::string("cross_entropy_lambda")) {
138
    return new CrossEntropyLambda(strs);
Guolin Ke's avatar
Guolin Ke committed
139
  } else if (type == std::string("mape")) {
Guolin Ke's avatar
Guolin Ke committed
140
    return new RegressionMAPELOSS(strs);
Guolin Ke's avatar
Guolin Ke committed
141
142
143
144
  } else if (type == std::string("gamma")) {
    return new RegressionGammaLoss(strs);
  } else if (type == std::string("tweedie")) {
    return new RegressionTweedieLoss(strs);
Guolin Ke's avatar
Guolin Ke committed
145
  } else if (type == std::string("custom")) {
146
    return nullptr;
147
  }
148
  Log::Fatal("Unknown objective type name: %s", type.c_str());
Guolin Ke's avatar
Guolin Ke committed
149
  return nullptr;
150
151
}

Guolin Ke's avatar
Guolin Ke committed
152
}  // namespace LightGBM