Commit bfb0217a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Move all prediction transform to the objective. (#383)

* many refactors.

* remove multi_loglossova.

* fix tests.

* avoid using lambda function.

* fix some format.

* reduce branching.
parent d4c4d9ae
......@@ -29,4 +29,30 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
}
return nullptr;
}
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
auto strs = Common::Split(str.c_str(), " ");
auto type = strs[0];
if (type == std::string("regression")) {
return new RegressionL2loss(strs);
} else if (type == std::string("regression_l1")) {
return new RegressionL1loss(strs);
} else if (type == std::string("huber")) {
return new RegressionHuberLoss(strs);
} else if (type == std::string("fair")) {
return new RegressionFairLoss(strs);
} else if (type == std::string("poisson")) {
return new RegressionPoissonLoss(strs);
} else if (type == std::string("binary")) {
return new BinaryLogloss(strs);
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(strs);
} else if (type == std::string("multiclass")) {
return new MulticlassSoftmax(strs);
} else if (type == std::string("multiclassova")) {
return new MulticlassOVA(strs);
}
return nullptr;
}
} // namespace LightGBM
......@@ -35,6 +35,11 @@ public:
Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_);
}
}
explicit LambdarankNDCG(const std::vector<std::string>&) {
}
~LambdarankNDCG() {
}
......@@ -196,6 +201,12 @@ public:
return "lambdarank";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
private:
/*! \brief Gains for labels */
std::vector<double> label_gain_;
......
......@@ -13,6 +13,10 @@ public:
explicit RegressionL2loss(const ObjectiveConfig&) {
}
explicit RegressionL2loss(const std::vector<std::string>&) {
}
~RegressionL2loss() {
}
......@@ -43,6 +47,12 @@ public:
return "regression";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
bool IsConstantHessian() const override {
if (weights_ == nullptr) {
return true;
......@@ -51,6 +61,8 @@ public:
}
}
bool BoostFromAverage() const override { return true; }
private:
/*! \brief Number of data */
data_size_t num_data_;
......@@ -69,6 +81,10 @@ public:
eta_ = static_cast<double>(config.gaussian_eta);
}
explicit RegressionL1loss(const std::vector<std::string>&) {
}
~RegressionL1loss() {}
void Init(const Metadata& metadata, data_size_t num_data) override {
......@@ -108,6 +124,14 @@ public:
return "regression_l1";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
bool BoostFromAverage() const override { return true; }
private:
/*! \brief Number of data */
data_size_t num_data_;
......@@ -129,6 +153,10 @@ public:
eta_ = static_cast<double>(config.gaussian_eta);
}
explicit RegressionHuberLoss(const std::vector<std::string>&) {
}
~RegressionHuberLoss() {
}
......@@ -181,6 +209,14 @@ public:
return "huber";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
bool BoostFromAverage() const override { return true; }
private:
/*! \brief Number of data */
data_size_t num_data_;
......@@ -202,6 +238,10 @@ public:
c_ = static_cast<double>(config.fair_c);
}
explicit RegressionFairLoss(const std::vector<std::string>&) {
}
~RegressionFairLoss() {}
void Init(const Metadata& metadata, data_size_t num_data) override {
......@@ -233,6 +273,14 @@ public:
return "fair";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
bool BoostFromAverage() const override { return true; }
private:
/*! \brief Number of data */
data_size_t num_data_;
......@@ -254,6 +302,10 @@ public:
max_delta_step_ = static_cast<double>(config.poisson_max_delta_step);
}
explicit RegressionPoissonLoss(const std::vector<std::string>&) {
}
~RegressionPoissonLoss() {}
void Init(const Metadata& metadata, data_size_t num_data) override {
......@@ -283,6 +335,14 @@ public:
return "poisson";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
bool BoostFromAverage() const override { return true; }
private:
/*! \brief Number of data */
data_size_t num_data_;
......
......@@ -222,6 +222,7 @@
<ClInclude Include="..\src\io\parser.hpp" />
<ClInclude Include="..\src\io\sparse_bin.hpp" />
<ClInclude Include="..\src\metric\binary_metric.hpp" />
<ClInclude Include="..\src\metric\map_metric.hpp" />
<ClInclude Include="..\src\metric\rank_metric.hpp" />
<ClInclude Include="..\src\metric\regression_metric.hpp" />
<ClInclude Include="..\src\metric\multiclass_metric.hpp" />
......
......@@ -180,6 +180,9 @@
<ClInclude Include="..\include\LightGBM\utils\openmp_wrapper.h">
<Filter>include\LightGBM\utils</Filter>
</ClInclude>
<ClInclude Include="..\src\metric\map_metric.hpp">
<Filter>src\metric</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment