Unverified Commit 4f28233b authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix multi-class objective (softmax) (#3256)



* Update multiclass_objective.hpp

* Apply suggestions from code review

* Update src/objective/multiclass_objective.hpp

* Apply suggestions from code review

* Update test_basic.R

* Update test_basic.R

* Update src/objective/multiclass_objective.hpp
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent b5027de3
...@@ -41,10 +41,10 @@ test_that("train and predict softmax", { ...@@ -41,10 +41,10 @@ test_that("train and predict softmax", {
data = as.matrix(iris[, -5L]) data = as.matrix(iris[, -5L])
, label = lb , label = lb
, num_leaves = 4L , num_leaves = 4L
, learning_rate = 0.1 , learning_rate = 0.05
, nrounds = 20L , nrounds = 20L
, min_data = 20L , min_data = 20L
, min_hessian = 20.0 , min_hessian = 10.0
, objective = "multiclass" , objective = "multiclass"
, metric = "multi_error" , metric = "multi_error"
, num_class = 3L , num_class = 3L
...@@ -53,7 +53,7 @@ test_that("train and predict softmax", { ...@@ -53,7 +53,7 @@ test_that("train and predict softmax", {
expect_false(is.null(bst$record_evals)) expect_false(is.null(bst$record_evals))
record_results <- lgb.get.eval.result(bst, "train", "multi_error") record_results <- lgb.get.eval.result(bst, "train", "multi_error")
expect_lt(min(record_results), 0.05) expect_lt(min(record_results), 0.06)
pred <- predict(bst, as.matrix(iris[, -5L])) pred <- predict(bst, as.matrix(iris[, -5L]))
expect_equal(length(pred), nrow(iris) * 3L) expect_equal(length(pred), nrow(iris) * 3L)
......
...@@ -25,6 +25,10 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -25,6 +25,10 @@ class MulticlassSoftmax: public ObjectiveFunction {
public: public:
explicit MulticlassSoftmax(const Config& config) { explicit MulticlassSoftmax(const Config& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
// This factor is to rescale the redundant form of K-classification, to the non-redundant form.
// In the traditional settings of K-classification, there is one redundant class, whose output is set to 0 (like the class 0 in binary classification).
// This is from the Friedman GBDT paper.
factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
} }
explicit MulticlassSoftmax(const std::vector<std::string>& strs) { explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
...@@ -40,6 +44,7 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -40,6 +44,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
if (num_class_ < 0) { if (num_class_ < 0) {
Log::Fatal("Objective should contain num_class field"); Log::Fatal("Objective should contain num_class field");
} }
factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
} }
~MulticlassSoftmax() { ~MulticlassSoftmax() {
...@@ -97,7 +102,7 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -97,7 +102,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} else { } else {
gradients[idx] = static_cast<score_t>(p); gradients[idx] = static_cast<score_t>(p);
} }
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p)); hessians[idx] = static_cast<score_t>(factor_ * p * (1.0f - p));
} }
} }
} else { } else {
...@@ -118,7 +123,7 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -118,7 +123,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} else { } else {
gradients[idx] = static_cast<score_t>((p) * weights_[i]); gradients[idx] = static_cast<score_t>((p) * weights_[i]);
} }
hessians[idx] = static_cast<score_t>((2.0f * p * (1.0f - p))* weights_[i]); hessians[idx] = static_cast<score_t>((factor_ * p * (1.0f - p))* weights_[i]);
} }
} }
} }
...@@ -161,6 +166,7 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -161,6 +166,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} }
private: private:
double factor_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */ /*! \brief Number of classes */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment