Unverified Commit 4f28233b authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix multi-class objective (softmax) (#3256)



* Update multiclass_objective.hpp

* Apply suggestions from code review

* Update src/objective/multiclass_objective.hpp

* Apply suggestions from code review

* Update test_basic.R

* Update test_basic.R

* Update src/objective/multiclass_objective.hpp
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent b5027de3
......@@ -41,10 +41,10 @@ test_that("train and predict softmax", {
data = as.matrix(iris[, -5L])
, label = lb
, num_leaves = 4L
, learning_rate = 0.1
, learning_rate = 0.05
, nrounds = 20L
, min_data = 20L
, min_hessian = 20.0
, min_hessian = 10.0
, objective = "multiclass"
, metric = "multi_error"
, num_class = 3L
......@@ -53,7 +53,7 @@ test_that("train and predict softmax", {
expect_false(is.null(bst$record_evals))
record_results <- lgb.get.eval.result(bst, "train", "multi_error")
expect_lt(min(record_results), 0.05)
expect_lt(min(record_results), 0.06)
pred <- predict(bst, as.matrix(iris[, -5L]))
expect_equal(length(pred), nrow(iris) * 3L)
......
......@@ -25,6 +25,10 @@ class MulticlassSoftmax: public ObjectiveFunction {
public:
explicit MulticlassSoftmax(const Config& config) {
num_class_ = config.num_class;
// This factor is to rescale the redundant form of K-classification, to the non-redundant form.
// In the traditional settings of K-classification, there is one redundant class, whose output is set to 0 (like the class 0 in binary classification).
// This is from the Friedman GBDT paper.
factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
}
explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
......@@ -40,6 +44,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
if (num_class_ < 0) {
Log::Fatal("Objective should contain num_class field");
}
factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
}
~MulticlassSoftmax() {
......@@ -97,7 +102,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} else {
gradients[idx] = static_cast<score_t>(p);
}
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
hessians[idx] = static_cast<score_t>(factor_ * p * (1.0f - p));
}
}
} else {
......@@ -118,7 +123,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} else {
gradients[idx] = static_cast<score_t>((p) * weights_[i]);
}
hessians[idx] = static_cast<score_t>((2.0f * p * (1.0f - p))* weights_[i]);
hessians[idx] = static_cast<score_t>((factor_ * p * (1.0f - p))* weights_[i]);
}
}
}
......@@ -161,6 +166,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
}
private:
double factor_;
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment