"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "a9df7f113f7e603ed06dcc04baf2bc55f28ba090"
Commit 7b4ead1e authored by Guolin Ke's avatar Guolin Ke
Browse files

convert the probabilities to raw score in boost_from_average of classification.

parent b38a19a4
...@@ -33,6 +33,10 @@ public: ...@@ -33,6 +33,10 @@ public:
virtual const char* GetName() const = 0; virtual const char* GetName() const = 0;
virtual std::vector<double> ConvertToRawScore(const std::vector<double>& preds) const {
return preds;
}
ObjectiveFunction() = default; ObjectiveFunction() = default;
/*! \brief Disable copy */ /*! \brief Disable copy */
ObjectiveFunction& operator=(const ObjectiveFunction&) = delete; ObjectiveFunction& operator=(const ObjectiveFunction&) = delete;
......
...@@ -306,13 +306,17 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -306,13 +306,17 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
sum_per_class[0] += label[i]; sum_per_class[0] += label[i];
} }
} }
std::vector<double > init_scores(num_class_);
for (int i = 0; i < num_class_; ++i) {
init_scores[i] = sum_per_class[i] / num_data_;
}
init_scores = object_function_->ConvertToRawScore(init_scores);
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
double init_score = sum_per_class[curr_class] / num_data_;
std::unique_ptr<Tree> new_tree(new Tree(2)); std::unique_ptr<Tree> new_tree(new Tree(2));
new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, num_data_, 1); new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_scores[curr_class], init_scores[curr_class], 0, num_data_, 1);
train_score_updater_->AddScore(init_score, curr_class); train_score_updater_->AddScore(init_scores[curr_class], curr_class);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(init_score, curr_class); score_updater->AddScore(init_scores[curr_class], curr_class);
} }
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
......
...@@ -86,6 +86,18 @@ public: ...@@ -86,6 +86,18 @@ public:
} }
} }
std::vector<double> ConvertToRawScore(const std::vector<double>& preds) const override {
std::vector<double> ret;
for (auto pred : preds) {
if (pred > kEpsilon && pred < 1.0f) {
ret.push_back(-std::log(1.0f / pred - 1.0f) / sigmoid_);
} else {
ret.push_back(0.0f);
}
}
return ret;
}
const char* GetName() const override { const char* GetName() const override {
return "binary"; return "binary";
} }
......
...@@ -93,6 +93,18 @@ public: ...@@ -93,6 +93,18 @@ public:
} }
} }
std::vector<double> ConvertToRawScore(const std::vector<double>& preds) const override {
std::vector<double> ret;
for (auto pred : preds) {
if (pred > kEpsilon) {
ret.push_back(std::log(pred));
} else {
ret.push_back(0);
}
}
return ret;
}
const char* GetName() const override { const char* GetName() const override {
return "multiclass"; return "multiclass";
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment