Commit ba99bcdd authored by Laurae's avatar Laurae Committed by Guolin Ke
Browse files

Switch RMSE to MSE (true L2 loss) (#408)

* RMSE (L2) -> MSE (true L2)

* Remove sqrt unneeded reference

* Square L2 test (RMSE to MSE)

* No square root on test

* Attempt to add RMSE
parent 18d6a902
...@@ -203,9 +203,10 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can ...@@ -203,9 +203,10 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can
## Metric parameters ## Metric parameters
* ```metric```, default={```l2``` for regression}, {```binary_logloss``` for binary classification},{```ndcg``` for lambdarank}, type=multi-enum, options=```l1```,```l2```,```ndcg```,```auc```,```binary_logloss```,```binary_error``` * ```metric```, default={```l2``` for regression}, {```binary_logloss``` for binary classification},{```ndcg``` for lambdarank}, type=multi-enum, options=```l1```,```l2```,```ndcg```,```auc```,```binary_logloss```,```binary_error```...
* ```l1```, absolute loss, alias=```mean_absolute_error```, ```mae``` * ```l1```, absolute loss, alias=```mean_absolute_error```, ```mae```
* ```l2```, square loss, alias=```mean_squared_error```, ```mse``` * ```l2```, square loss, alias=```mean_squared_error```, ```mse```
* ```l2_root```, root square loss, alias=```root_mean_squared_error```, ```rmse```
* ```huber```, [Huber loss](https://en.wikipedia.org/wiki/Huber_loss "Huber loss - Wikipedia") * ```huber```, [Huber loss](https://en.wikipedia.org/wiki/Huber_loss "Huber loss - Wikipedia")
* ```fair```, [Fair loss](https://www.kaggle.com/c/allstate-claims-severity/discussion/24520) * ```fair```, [Fair loss](https://www.kaggle.com/c/allstate-claims-severity/discussion/24520)
* ```poisson```, [Poisson regression](https://en.wikipedia.org/wiki/Poisson_regression "Poisson regression") * ```poisson```, [Poisson regression](https://en.wikipedia.org/wiki/Poisson_regression "Poisson regression")
......
...@@ -10,6 +10,8 @@ namespace LightGBM { ...@@ -10,6 +10,8 @@ namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) { Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) { if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config); return new L2Metric(config);
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RMSEMetric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) { } else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new L1Metric(config); return new L1Metric(config);
} else if (type == std::string("huber")) { } else if (type == std::string("huber")) {
......
...@@ -111,6 +111,25 @@ private: ...@@ -111,6 +111,25 @@ private:
std::vector<std::string> name_; std::vector<std::string> name_;
}; };
/*! \brief RMSE loss for regression task */
class RMSEMetric: public RegressionMetric<RMSEMetric> {
public:
explicit RMSEMetric(const MetricConfig& config) :RegressionMetric<RMSEMetric>(config) {}
inline static double LossOnPoint(float label, double score, double, double) {
return (score - label)*(score - label);
}
inline static double AverageLoss(double sum_loss, double sum_weights) {
// need sqrt the result for RMSE loss
return std::sqrt(sum_loss / sum_weights);
}
inline static const char* Name() {
return "rmse";
}
};
/*! \brief L2 loss for regression task */ /*! \brief L2 loss for regression task */
class L2Metric: public RegressionMetric<L2Metric> { class L2Metric: public RegressionMetric<L2Metric> {
public: public:
...@@ -121,8 +140,8 @@ public: ...@@ -121,8 +140,8 @@ public:
} }
inline static double AverageLoss(double sum_loss, double sum_weights) { inline static double AverageLoss(double sum_loss, double sum_weights) {
// need sqrt the result for L2 loss // need mean of the result for L2 loss
return std::sqrt(sum_loss / sum_weights); return sum_loss / sum_weights;
} }
inline static const char* Name() { inline static const char* Name() {
......
...@@ -71,8 +71,7 @@ class TestEngine(unittest.TestCase): ...@@ -71,8 +71,7 @@ class TestEngine(unittest.TestCase):
def test_regreesion(self): def test_regreesion(self):
evals_result, ret = template.test_template() evals_result, ret = template.test_template()
ret **= 0.5 self.assertLess(ret, 16)
self.assertLess(ret, 4)
self.assertAlmostEqual(min(evals_result['eval']['l2']), ret, places=5) self.assertAlmostEqual(min(evals_result['eval']['l2']), ret, places=5)
def test_multiclass(self): def test_multiclass(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment