Commit fd5b5916 authored by Guolin Ke
Browse files

fix subset bug in 4bits_bin.

parent 5f5c61f3
......@@ -20,6 +20,7 @@ namespace LightGBM {
#ifdef TIMETAG
std::chrono::duration<double, std::milli> boosting_time;
std::chrono::duration<double, std::milli> train_score_time;
std::chrono::duration<double, std::milli> out_of_bag_score_time;
std::chrono::duration<double, std::milli> valid_score_time;
std::chrono::duration<double, std::milli> metric_time;
std::chrono::duration<double, std::milli> bagging_time;
......@@ -49,6 +50,7 @@ GBDT::~GBDT() {
#ifdef TIMETAG
Log::Info("GBDT::boosting costs %f", boosting_time * 1e-3);
Log::Info("GBDT::train_score costs %f", train_score_time * 1e-3);
Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time * 1e-3);
Log::Info("GBDT::valid_score costs %f", valid_score_time * 1e-3);
Log::Info("GBDT::metric costs %f", metric_time * 1e-3);
Log::Info("GBDT::bagging costs %f", bagging_time * 1e-3);
......@@ -285,7 +287,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
}
#ifdef TIMETAG
train_score_time += std::chrono::steady_clock::now() - start_time;
out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
#endif
}
......
......@@ -26,7 +26,6 @@ public:
}
// Destructor with an empty body: no explicit cleanup is performed here.
~GOSS() {
}
......@@ -37,7 +36,7 @@ public:
CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
Log::Fatal("cannot used bagging in GOSS");
Log::Fatal("cannot use bagging in GOSS");
}
Log::Info("using GOSS");
}
......@@ -45,7 +44,7 @@ public:
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) override {
if (config->bagging_freq > 0 && config->bagging_fraction != 1.0f) {
Log::Fatal("cannot used bagging in GOSS");
Log::Fatal("cannot use bagging in GOSS");
}
GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
if (train_data_ == nullptr) { return; }
......@@ -118,7 +117,7 @@ public:
// not subsample for first iterations
if (iter < static_cast<int>(1.0f / gbdt_config_->learning_rate)) { return; }
const data_size_t min_inner_size = 1000;
const data_size_t min_inner_size = 100;
data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
if (inner_size < min_inner_size) { inner_size = min_inner_size; }
......
......@@ -211,7 +211,10 @@ public:
const data_size_t idx = local_used_indices[i];
const auto bin = (mem_data[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
// add
Push(0, i, bin);
const int i1 = i >> 1;
const int i2 = (i & 1) << 2;
const uint8_t val = static_cast<uint8_t>(bin) << i2;
data_[i1] |= val;
}
} else {
for (size_t i = 0; i < data_.size(); ++i) {
......@@ -225,7 +228,10 @@ public:
for (int i = 0; i < num_used_indices; ++i) {
const data_size_t idx = used_indices[i];
const auto bin = (other_bin->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
Push(0, i, bin);
const int i1 = i >> 1;
const int i2 = (i & 1) << 2;
const uint8_t val = static_cast<uint8_t>(bin) << i2;
data_[i1] |= val;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment