Commit bfb0217a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Move all prediction transform to the objective. (#383)

* many refactors.

* remove multi_loglossova.

* fix tests.

* avoid using lambda function.

* fix some format.

* reduce branching.
parent d4c4d9ae
...@@ -29,4 +29,30 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -29,4 +29,30 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} }
return nullptr; return nullptr;
} }
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
  // Re-creates an objective function from its serialized (model-file) form.
  // The first space-separated token is the objective's name; the full token
  // list is handed to the objective's constructor so it can restore any
  // parameters it serialized.
  auto strs = Common::Split(str.c_str(), " ");
  if (strs.empty()) {
    // Guard: an empty/whitespace-only string would make strs[0] UB.
    // Treat it the same as an unknown objective name.
    return nullptr;
  }
  const auto& type = strs[0];
  // std::string == const char* compares contents; no temporary needed.
  if (type == "regression") {
    return new RegressionL2loss(strs);
  } else if (type == "regression_l1") {
    return new RegressionL1loss(strs);
  } else if (type == "huber") {
    return new RegressionHuberLoss(strs);
  } else if (type == "fair") {
    return new RegressionFairLoss(strs);
  } else if (type == "poisson") {
    return new RegressionPoissonLoss(strs);
  } else if (type == "binary") {
    return new BinaryLogloss(strs);
  } else if (type == "lambdarank") {
    return new LambdarankNDCG(strs);
  } else if (type == "multiclass") {
    return new MulticlassSoftmax(strs);
  } else if (type == "multiclassova") {
    return new MulticlassOVA(strs);
  }
  // Unknown objective name: callers are expected to handle nullptr.
  return nullptr;
}
} // namespace LightGBM } // namespace LightGBM
...@@ -35,6 +35,11 @@ public: ...@@ -35,6 +35,11 @@ public:
Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_); Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_);
} }
} }
// Deserialization constructor: rebuilds the objective from the tokenized
// model-file string produced by ToString(). NOTE(review): the tokens are
// ignored — lambdarank appears to persist only its name, so there is
// nothing to restore here; confirm against ToString().
explicit LambdarankNDCG(const std::vector<std::string>&) {
}
~LambdarankNDCG() { ~LambdarankNDCG() {
} }
...@@ -196,6 +201,12 @@ public: ...@@ -196,6 +201,12 @@ public:
return "lambdarank"; return "lambdarank";
} }
// Serialized form of this objective for the model file: just the objective
// name ("lambdarank"), since no extra parameters are persisted here.
// Returning GetName() directly avoids a needless stringstream allocation.
std::string ToString() const override {
  return GetName();
}
private: private:
/*! \brief Gains for labels */ /*! \brief Gains for labels */
std::vector<double> label_gain_; std::vector<double> label_gain_;
......
...@@ -13,6 +13,10 @@ public: ...@@ -13,6 +13,10 @@ public:
explicit RegressionL2loss(const ObjectiveConfig&) { explicit RegressionL2loss(const ObjectiveConfig&) {
} }
// Deserialization constructor: rebuilds the objective from the tokenized
// model-file string produced by ToString(). NOTE(review): the tokens are
// ignored — L2 loss persists only its name, so nothing is restored here.
explicit RegressionL2loss(const std::vector<std::string>&) {
}
~RegressionL2loss() { ~RegressionL2loss() {
} }
...@@ -43,6 +47,12 @@ public: ...@@ -43,6 +47,12 @@ public:
return "regression"; return "regression";
} }
// Serialized form of this objective for the model file: just the objective
// name ("regression"), since no extra parameters are persisted here.
// Returning GetName() directly avoids a needless stringstream allocation.
std::string ToString() const override {
  return GetName();
}
bool IsConstantHessian() const override { bool IsConstantHessian() const override {
if (weights_ == nullptr) { if (weights_ == nullptr) {
return true; return true;
...@@ -51,6 +61,8 @@ public: ...@@ -51,6 +61,8 @@ public:
} }
} }
// Opts this objective into starting boosting from an averaged initial score
// rather than zero — presumably the mean label; confirm against the caller.
bool BoostFromAverage() const override { return true; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
...@@ -69,6 +81,10 @@ public: ...@@ -69,6 +81,10 @@ public:
eta_ = static_cast<double>(config.gaussian_eta); eta_ = static_cast<double>(config.gaussian_eta);
} }
// Deserialization constructor: rebuilds the objective from the tokenized
// model-file string produced by ToString(). NOTE(review): the tokens are
// ignored — L1 loss persists only its name, so nothing is restored here.
explicit RegressionL1loss(const std::vector<std::string>&) {
}
~RegressionL1loss() {} ~RegressionL1loss() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
...@@ -108,6 +124,14 @@ public: ...@@ -108,6 +124,14 @@ public:
return "regression_l1"; return "regression_l1";
} }
// Serialized form of this objective for the model file: just the objective
// name ("regression_l1"), since no extra parameters are persisted here.
// Returning GetName() directly avoids a needless stringstream allocation.
std::string ToString() const override {
  return GetName();
}
// Opts this objective into starting boosting from an averaged initial score
// rather than zero — presumably the mean label; confirm against the caller.
bool BoostFromAverage() const override { return true; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
...@@ -129,6 +153,10 @@ public: ...@@ -129,6 +153,10 @@ public:
eta_ = static_cast<double>(config.gaussian_eta); eta_ = static_cast<double>(config.gaussian_eta);
} }
// Deserialization constructor: rebuilds the objective from the tokenized
// model-file string produced by ToString(). NOTE(review): the tokens are
// ignored, so the Huber delta/eta read from config at train time are NOT
// restored from the model string — verify this is intended.
explicit RegressionHuberLoss(const std::vector<std::string>&) {
}
~RegressionHuberLoss() { ~RegressionHuberLoss() {
} }
...@@ -181,6 +209,14 @@ public: ...@@ -181,6 +209,14 @@ public:
return "huber"; return "huber";
} }
// Serialized form of this objective for the model file: just the objective
// name ("huber"), since no extra parameters are persisted here.
// Returning GetName() directly avoids a needless stringstream allocation.
std::string ToString() const override {
  return GetName();
}
// Opts this objective into starting boosting from an averaged initial score
// rather than zero — presumably the mean label; confirm against the caller.
bool BoostFromAverage() const override { return true; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
...@@ -202,6 +238,10 @@ public: ...@@ -202,6 +238,10 @@ public:
c_ = static_cast<double>(config.fair_c); c_ = static_cast<double>(config.fair_c);
} }
// Deserialization constructor: rebuilds the objective from the tokenized
// model-file string produced by ToString(). NOTE(review): the tokens are
// ignored, so the Fair-loss constant c_ read from config at train time is
// NOT restored from the model string — verify this is intended.
explicit RegressionFairLoss(const std::vector<std::string>&) {
}
~RegressionFairLoss() {} ~RegressionFairLoss() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
...@@ -233,6 +273,14 @@ public: ...@@ -233,6 +273,14 @@ public:
return "fair"; return "fair";
} }
// Serialized form of this objective for the model file: just the objective
// name ("fair"), since no extra parameters are persisted here.
// Returning GetName() directly avoids a needless stringstream allocation.
std::string ToString() const override {
  return GetName();
}
// Opts this objective into starting boosting from an averaged initial score
// rather than zero — presumably the mean label; confirm against the caller.
bool BoostFromAverage() const override { return true; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
...@@ -254,6 +302,10 @@ public: ...@@ -254,6 +302,10 @@ public:
max_delta_step_ = static_cast<double>(config.poisson_max_delta_step); max_delta_step_ = static_cast<double>(config.poisson_max_delta_step);
} }
// Deserialization constructor: rebuilds the objective from the tokenized
// model-file string produced by ToString(). NOTE(review): the tokens are
// ignored, so max_delta_step_ read from config at train time is NOT
// restored from the model string — verify this is intended.
explicit RegressionPoissonLoss(const std::vector<std::string>&) {
}
~RegressionPoissonLoss() {} ~RegressionPoissonLoss() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
...@@ -283,6 +335,14 @@ public: ...@@ -283,6 +335,14 @@ public:
return "poisson"; return "poisson";
} }
// Serialized form of this objective for the model file: just the objective
// name ("poisson"), since no extra parameters are persisted here.
// Returning GetName() directly avoids a needless stringstream allocation.
std::string ToString() const override {
  return GetName();
}
// Opts this objective into starting boosting from an averaged initial score
// rather than zero — presumably the mean label; confirm against the caller.
bool BoostFromAverage() const override { return true; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
......
...@@ -222,6 +222,7 @@ ...@@ -222,6 +222,7 @@
<ClInclude Include="..\src\io\parser.hpp" /> <ClInclude Include="..\src\io\parser.hpp" />
<ClInclude Include="..\src\io\sparse_bin.hpp" /> <ClInclude Include="..\src\io\sparse_bin.hpp" />
<ClInclude Include="..\src\metric\binary_metric.hpp" /> <ClInclude Include="..\src\metric\binary_metric.hpp" />
<ClInclude Include="..\src\metric\map_metric.hpp" />
<ClInclude Include="..\src\metric\rank_metric.hpp" /> <ClInclude Include="..\src\metric\rank_metric.hpp" />
<ClInclude Include="..\src\metric\regression_metric.hpp" /> <ClInclude Include="..\src\metric\regression_metric.hpp" />
<ClInclude Include="..\src\metric\multiclass_metric.hpp" /> <ClInclude Include="..\src\metric\multiclass_metric.hpp" />
......
...@@ -180,6 +180,9 @@ ...@@ -180,6 +180,9 @@
<ClInclude Include="..\include\LightGBM\utils\openmp_wrapper.h"> <ClInclude Include="..\include\LightGBM\utils\openmp_wrapper.h">
<Filter>include\LightGBM\utils</Filter> <Filter>include\LightGBM\utils</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\metric\map_metric.hpp">
<Filter>src\metric</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\src\application\application.cpp"> <ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment