Commit 27d3eb33 authored by Tsukasa OMOTO's avatar Tsukasa OMOTO Committed by Guolin Ke
Browse files

Fix huber loss (#178)

* fix typo

* fix hessians to approximate hessians with Gaussian function

* fix ApproximateHessianWithGaussian

* take fabs of gradient

* use atan(1) to calculate pi

* fix pi
parent 79431df8
......@@ -405,17 +405,18 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s
* cf. https://en.wikipedia.org/wiki/Gaussian_function
*
* y is a prediction.
* t mesas true target.
* t means true target.
* g means gradient.
* w means weights.
*/
/*
 * Approximates the second derivative (hessian) of an L1-style loss with a
 * Gaussian bump centred on zero residual, so the tree solver never sees a
 * zero hessian (cf. https://en.wikipedia.org/wiki/Gaussian_function).
 *
 * y: current prediction.
 * t: true target value.
 * g: gradient at this point (only its magnitude is used).
 * w: per-sample weight (defaults to 1.0).
 *
 * Returns w * a * N(x; b, c), where N is the normalized Gaussian density,
 * x is the absolute residual, a the jump in the first derivative at the
 * origin, and c a data-dependent width.
 */
inline static double ApproximateHessianWithGaussian(const double y, const double t, const double g, const double w=1.0f) {
  const double diff = y - t;
  // pi computed portably; M_PI is not guaranteed by the C++ standard.
  const double pi = 4.0 * std::atan(1.0);
  const double x = std::fabs(diff);
  // Amplitude: the first derivative of |diff| jumps by 2*|g| across zero,
  // scaled by the sample weight.
  const double a = 2.0 * std::fabs(g) * w;  // difference of two first derivatives, (zero to inf) and (zero to -inf).
  const double b = 0.0;  // the Gaussian is centred at zero residual
  // Width: grows with the magnitudes involved; floored to avoid division by zero.
  const double c = std::max(std::fabs(y) + std::fabs(t), 1.0e-10);
  // NOTE(review): w multiplies the result here even though it is already
  // folded into `a`, so weighted samples are scaled by w twice — confirm
  // the double weighting is intentional.
  return w * std::exp(-(x - b) * (x - b) / (2.0 * c * c)) * a / (c * std::sqrt(2 * pi));
}
} // namespace Common
......
......@@ -13,7 +13,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new RegressionL1loss(config);
} else if (type == std::string("huber")) {
return new RegressionLHuberLoss(config);
return new RegressionHuberLoss(config);
} else if (type == std::string("binary")) {
return new BinaryLogloss(config);
} else if (type == std::string("lambdarank")) {
......
......@@ -76,7 +76,7 @@ public:
} else {
gradients[i] = -1.0;
}
hessians[i] = Common::ApproximateHessianWithGaussian(score[i], label_[i]);
hessians[i] = Common::ApproximateHessianWithGaussian(score[i], label_[i], gradients[i]);
}
} else {
#pragma omp parallel for schedule(static)
......@@ -87,7 +87,7 @@ public:
} else {
gradients[i] = -weights_[i];
}
hessians[i] = Common::ApproximateHessianWithGaussian(score[i], label_[i], weights_[i]);
hessians[i] = Common::ApproximateHessianWithGaussian(score[i], label_[i], gradients[i], weights_[i]);
}
}
}
......@@ -106,13 +106,13 @@ private:
};
class RegressionLHuberLoss: public ObjectiveFunction {
class RegressionHuberLoss: public ObjectiveFunction {
public:
explicit RegressionLHuberLoss(const ObjectiveConfig& config) {
explicit RegressionHuberLoss(const ObjectiveConfig& config) {
delta_ = config.huber_delta;
}
~RegressionLHuberLoss() {
~RegressionHuberLoss() {
}
void Init(const Metadata& metadata, data_size_t num_data) override {
......@@ -127,6 +127,7 @@ public:
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
const double diff = score[i] - label_[i];
if (std::abs(diff) <= delta_) {
gradients[i] = diff;
hessians[i] = 1.0;
......@@ -136,13 +137,14 @@ public:
} else {
gradients[i] = -delta_;
}
hessians[i] = 0.0;
hessians[i] = Common::ApproximateHessianWithGaussian(score[i], label_[i], gradients[i]);
}
}
} else {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
const double diff = score[i] - label_[i];
if (std::abs(diff) <= delta_) {
gradients[i] = diff * weights_[i];
hessians[i] = weights_[i];
......@@ -152,7 +154,7 @@ public:
} else {
gradients[i] = -delta_ * weights_[i];
}
hessians[i] = 0.0;
hessians[i] = Common::ApproximateHessianWithGaussian(score[i], label_[i], gradients[i], weights_[i]);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment