Unverified Commit f1328d5c authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

Clear split info buffer in cost efficient gradient boosting before every...

Clear split info buffer in cost efficient gradient boosting before every iteration (fix partially #3679) (#5164)

* clear split info buffer in cegb_ before every iteration

* check nullable of cegb_ in serial_tree_learner.cpp

* add a test case for checking the split buffer in CEGB

* switch to Threading::For instead of raw OpenMP

* apply review suggestions

* apply review comments

* remove device cpu
parent 27d9ad2e
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <LightGBM/dataset.h> #include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
#include <LightGBM/utils/threading.h>
#include <vector> #include <vector>
...@@ -32,6 +33,7 @@ class CostEfficientGradientBoosting { ...@@ -32,6 +33,7 @@ class CostEfficientGradientBoosting {
return true; return true;
} }
} }
void Init() { void Init() {
auto train_data = tree_learner_->train_data_; auto train_data = tree_learner_->train_data_;
if (!init_) { if (!init_) {
...@@ -63,6 +65,17 @@ class CostEfficientGradientBoosting { ...@@ -63,6 +65,17 @@ class CostEfficientGradientBoosting {
} }
init_ = true; init_ = true;
} }
void BeforeTrain() {
// clear the splits in splits_per_leaf_
Threading::For<size_t>(0, splits_per_leaf_.size(), 1024,
[this] (int /*thread_index*/, size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
splits_per_leaf_[i].Reset();
}
});
}
double DeltaGain(int feature_index, int real_fidx, int leaf_index, double DeltaGain(int feature_index, int real_fidx, int leaf_index,
int num_data_in_leaf, SplitInfo split_info) { int num_data_in_leaf, SplitInfo split_info) {
auto config = tree_learner_->config_; auto config = tree_learner_->config_;
...@@ -82,6 +95,7 @@ class CostEfficientGradientBoosting { ...@@ -82,6 +95,7 @@ class CostEfficientGradientBoosting {
feature_index] = split_info; feature_index] = split_info;
return delta; return delta;
} }
void UpdateLeafBestSplits(Tree* tree, int best_leaf, void UpdateLeafBestSplits(Tree* tree, int best_leaf,
const SplitInfo* best_split_info, const SplitInfo* best_split_info,
std::vector<SplitInfo>* best_split_per_leaf) { std::vector<SplitInfo>* best_split_per_leaf) {
......
...@@ -278,6 +278,10 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -278,6 +278,10 @@ void SerialTreeLearner::BeforeTrain() {
} }
larger_leaf_splits_->Init(); larger_leaf_splits_->Init();
if (cegb_ != nullptr) {
cegb_->BeforeTrain();
}
} }
bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) { bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
......
...@@ -3578,3 +3578,48 @@ def test_boost_from_average_with_single_leaf_trees(): ...@@ -3578,3 +3578,48 @@ def test_boost_from_average_with_single_leaf_trees():
preds = model.predict(X) preds = model.predict(X)
mean_preds = np.mean(preds) mean_preds = np.mean(preds)
assert y.min() <= mean_preds <= y.max() assert y.min() <= mean_preds <= y.max()
def test_cegb_split_buffer_clean():
    # modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811
    # and https://github.com/microsoft/LightGBM/pull/5087
    # test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree
    # which is done in the fix #5164
    # without the fix:
    # Check failed: (best_split_info.left_count) > (0)
    # NOTE: the sequence of np.random calls below must not be reordered —
    # the fixture (and hence the assertion threshold) depends on it.
    R, C = 1000, 100
    seed = 29
    np.random.seed(seed)
    X = np.random.randn(R, C)
    for i in range(1, C):
        X[i] += X[0] * np.random.randn()

    # 80/20 split into train and held-out rows; targets are the row sums.
    N = int(0.8 * len(X))
    X_train, X_test = X[:N], X[N:]
    y_train = np.sum(X_train, axis=1)
    y_test = np.sum(X_test, axis=1)
    ds_train = lgb.Dataset(X_train, y_train, free_raw_data=True)

    params = {
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'max_bin': 255,
        'num_leaves': 31,
        'seed': 0,
        'learning_rate': 0.1,
        'min_data_in_leaf': 0,
        'verbose': -1,
        'min_split_gain': 1000.0,
        'cegb_penalty_feature_coupled': 5 * np.arange(C),
        'cegb_penalty_split': 0.0002,
        'cegb_tradeoff': 10.0,
        'force_col_wise': True,
    }

    booster = lgb.train(params, ds_train, num_boost_round=10)
    preds = booster.predict(X_test)
    rmse = np.sqrt(mean_squared_error(y_test, preds))
    assert rmse < 10.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment