Commit fd5b5916 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix subset bug in 4bits_bin.

parent 5f5c61f3
...@@ -20,6 +20,7 @@ namespace LightGBM { ...@@ -20,6 +20,7 @@ namespace LightGBM {
#ifdef TIMETAG #ifdef TIMETAG
std::chrono::duration<double, std::milli> boosting_time; std::chrono::duration<double, std::milli> boosting_time;
std::chrono::duration<double, std::milli> train_score_time; std::chrono::duration<double, std::milli> train_score_time;
std::chrono::duration<double, std::milli> out_of_bag_score_time;
std::chrono::duration<double, std::milli> valid_score_time; std::chrono::duration<double, std::milli> valid_score_time;
std::chrono::duration<double, std::milli> metric_time; std::chrono::duration<double, std::milli> metric_time;
std::chrono::duration<double, std::milli> bagging_time; std::chrono::duration<double, std::milli> bagging_time;
...@@ -49,6 +50,7 @@ GBDT::~GBDT() { ...@@ -49,6 +50,7 @@ GBDT::~GBDT() {
#ifdef TIMETAG #ifdef TIMETAG
Log::Info("GBDT::boosting costs %f", boosting_time * 1e-3); Log::Info("GBDT::boosting costs %f", boosting_time * 1e-3);
Log::Info("GBDT::train_score costs %f", train_score_time * 1e-3); Log::Info("GBDT::train_score costs %f", train_score_time * 1e-3);
Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time * 1e-3);
Log::Info("GBDT::valid_score costs %f", valid_score_time * 1e-3); Log::Info("GBDT::valid_score costs %f", valid_score_time * 1e-3);
Log::Info("GBDT::metric costs %f", metric_time * 1e-3); Log::Info("GBDT::metric costs %f", metric_time * 1e-3);
Log::Info("GBDT::bagging costs %f", bagging_time * 1e-3); Log::Info("GBDT::bagging costs %f", bagging_time * 1e-3);
...@@ -285,7 +287,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) { ...@@ -285,7 +287,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class); train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
} }
#ifdef TIMETAG #ifdef TIMETAG
train_score_time += std::chrono::steady_clock::now() - start_time; out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
......
...@@ -26,7 +26,6 @@ public: ...@@ -26,7 +26,6 @@ public:
} }
~GOSS() { ~GOSS() {
} }
...@@ -37,7 +36,7 @@ public: ...@@ -37,7 +36,7 @@ public:
CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f); CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f); CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) { if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
Log::Fatal("cannot used bagging in GOSS"); Log::Fatal("cannot use bagging in GOSS");
} }
Log::Info("using GOSS"); Log::Info("using GOSS");
} }
...@@ -45,7 +44,7 @@ public: ...@@ -45,7 +44,7 @@ public:
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) override { const std::vector<const Metric*>& training_metrics) override {
if (config->bagging_freq > 0 && config->bagging_fraction != 1.0f) { if (config->bagging_freq > 0 && config->bagging_fraction != 1.0f) {
Log::Fatal("cannot used bagging in GOSS"); Log::Fatal("cannot use bagging in GOSS");
} }
GBDT::ResetTrainingData(config, train_data, object_function, training_metrics); GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
if (train_data_ == nullptr) { return; } if (train_data_ == nullptr) { return; }
...@@ -118,7 +117,7 @@ public: ...@@ -118,7 +117,7 @@ public:
// not subsample for first iterations // not subsample for first iterations
if (iter < static_cast<int>(1.0f / gbdt_config_->learning_rate)) { return; } if (iter < static_cast<int>(1.0f / gbdt_config_->learning_rate)) { return; }
const data_size_t min_inner_size = 1000; const data_size_t min_inner_size = 100;
data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_; data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
if (inner_size < min_inner_size) { inner_size = min_inner_size; } if (inner_size < min_inner_size) { inner_size = min_inner_size; }
......
...@@ -211,7 +211,10 @@ public: ...@@ -211,7 +211,10 @@ public:
const data_size_t idx = local_used_indices[i]; const data_size_t idx = local_used_indices[i];
const auto bin = (mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const auto bin = (mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
// add // add
Push(0, i, bin); const int i1 = i >> 1;
const int i2 = (i & 1) << 2;
const uint8_t val = static_cast<uint8_t>(bin) << i2;
data_[i1] |= val;
} }
} else { } else {
for (size_t i = 0; i < data_.size(); ++i) { for (size_t i = 0; i < data_.size(); ++i) {
...@@ -225,7 +228,10 @@ public: ...@@ -225,7 +228,10 @@ public:
for (int i = 0; i < num_used_indices; ++i) { for (int i = 0; i < num_used_indices; ++i) {
const data_size_t idx = used_indices[i]; const data_size_t idx = used_indices[i];
const auto bin = (other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const auto bin = (other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
Push(0, i, bin); const int i1 = i >> 1;
const int i2 = (i & 1) << 2;
const uint8_t val = static_cast<uint8_t>(bin) << i2;
data_[i1] |= val;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment