"python-package/vscode:/vscode.git/clone" did not exist on "b60068c810cbcac9cf4e1a8e678d8d531c40eb72"
Commit f40e0d2e authored by Guolin Ke's avatar Guolin Ke
Browse files

add weight decay in softmax loss.

parent 4398906d
...@@ -17,6 +17,7 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -17,6 +17,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
public: public:
explicit MulticlassSoftmax(const ObjectiveConfig& config) { explicit MulticlassSoftmax(const ObjectiveConfig& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
softmax_weight_decay_ = 1e-3;
} }
~MulticlassSoftmax() { ~MulticlassSoftmax() {
...@@ -35,6 +36,7 @@ public: ...@@ -35,6 +36,7 @@ public:
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]); Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
} }
} }
hessian_nor_ = static_cast<score_t>(num_class_) / (num_class_ - 1);
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
...@@ -52,11 +54,11 @@ public: ...@@ -52,11 +54,11 @@ public:
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f); gradients[idx] = static_cast<score_t>(p - 1.0f + softmax_weight_decay_ * score[idx]);
} else { } else {
gradients[idx] = static_cast<score_t>(p); gradients[idx] = static_cast<score_t>(p + softmax_weight_decay_ * score[idx]);
} }
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p)); hessians[idx] = static_cast<score_t>(hessian_nor_ * p * (1.0f - p) + softmax_weight_decay_);
} }
} }
} else { } else {
...@@ -73,11 +75,11 @@ public: ...@@ -73,11 +75,11 @@ public:
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]); gradients[idx] = static_cast<score_t>((p - 1.0f + softmax_weight_decay_ * score[idx]) * weights_[i]);
} else { } else {
gradients[idx] = static_cast<score_t>(p * weights_[i]); gradients[idx] = static_cast<score_t>((p + softmax_weight_decay_ * score[idx]) * weights_[i]);
} }
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]); hessians[idx] = static_cast<score_t>((hessian_nor_ * p * (1.0f - p) + softmax_weight_decay_)* weights_[i]);
} }
} }
} }
...@@ -98,6 +100,8 @@ private: ...@@ -98,6 +100,8 @@ private:
std::vector<int> label_int_; std::vector<int> label_int_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const float* weights_;
double softmax_weight_decay_;
score_t hessian_nor_;
}; };
/*! /*!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment