Unverified Commit c7e90393 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support most frequent bin (#2689)

* implement

* fix warning

* fix bug

* fix a bug

* remove unneeded function

* fix data push bug

* fix valid data push

* fix bug for missing_type=zero

* refine split

* renames

* typo
parent 82886ba6
......@@ -117,6 +117,7 @@ class BinMapper {
return bin_2_categorical_[bin];
}
}
/*!
* \brief Get sizes in byte of this object
*/
......@@ -135,6 +136,11 @@ class BinMapper {
/*! \brief Get the bin that feature value 0 is mapped to */
inline uint32_t GetDefaultBin() const { return default_bin_; }
/*! \brief Get the most frequently occurring bin of this feature */
inline uint32_t GetMostFreqBin() const { return most_freq_bin_; }
/*!
* \brief Construct feature value to bin mapper according feature values
* \param values (Sampled) values of this feature. Note: zeros are not included.
......@@ -211,6 +217,8 @@ class BinMapper {
double max_val_;
/*! \brief bin value of feature value 0 */
uint32_t default_bin_;
uint32_t most_freq_bin_;
};
/*!
......@@ -306,10 +314,10 @@ class Bin {
* \brief Get bin iterator of this bin for specific feature
* \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin]
* \param most_freq_bin The most frequently occurring bin of the feature
* \return Iterator of this bin
*/
virtual BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const = 0;
virtual BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const = 0;
/*!
* \brief Save binary data to file
......@@ -381,7 +389,8 @@ class Bin {
* \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices)
* \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin]
* \param default_bin default bin for feature value 0
* \param most_freq_bin The most frequently occurring bin of the feature
* \param missing_type missing type
* \param default_left missing bin will go to left child
* \param threshold The split threshold.
......@@ -392,7 +401,7 @@ class Bin {
* \return The number of less than or equal data.
*/
virtual data_size_t Split(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t threshold,
uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left, uint32_t threshold,
data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
......@@ -400,7 +409,7 @@ class Bin {
* \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices)
* \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin]
* \param most_freq_bin The most frequently occurring bin of the feature
* \param threshold The split threshold.
* \param num_threshold Number of threshold
* \param data_indices Used data indices. After called this function. The less than or equal data indices will store on this object.
......@@ -410,7 +419,7 @@ class Bin {
* \return The number of less than or equal data.
*/
virtual data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, const uint32_t* threshold, int num_threshold,
uint32_t most_freq_bin, const uint32_t* threshold, int num_threshold,
data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
......@@ -433,7 +442,6 @@ class Bin {
* \param is_enable_sparse True if enable sparse feature
* \param sparse_threshold Threshold for treating a feature as a sparse feature
* \param is_sparse Will set to true if this bin is sparse
* \param default_bin Default bin for zeros value
* \return The bin data object
*/
static Bin* CreateBin(data_size_t num_data, int num_bin,
......
......@@ -293,6 +293,7 @@ class Dataset {
int num_total_features,
const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices,
double** sample_values,
const int* num_per_col,
int num_sample_col,
size_t total_sample_cnt,
......@@ -319,6 +320,16 @@ class Dataset {
return true;
}
/*!
* \brief Finish pushing one sparse row: explicitly push value 0 for every feature
*        whose most frequent bin differs from the zero bin and that did not
*        receive a value in this row.
* \param tid thread index used by the underlying bin storage
* \param row_idx index of the row that was just pushed
* \param is_feature_added flags marking which inner features already got a value
*/
inline void FinishOneRow(int tid, data_size_t row_idx, const std::vector<bool>& is_feature_added) {
  if (is_finish_load_) { return; }
  for (auto fidx : feature_need_push_zeros_) {
    if (!is_feature_added[fidx]) {
      // the sparse push skipped this zero; insert it so binning stays correct
      feature_groups_[feature2group_[fidx]]->PushData(tid, feature2subfeature_[fidx], row_idx, 0.0f);
    }
  }
}
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
if (is_finish_load_) { return; }
for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
......@@ -333,15 +344,18 @@ class Dataset {
/*!
* \brief Push one sparse row given as (feature index, value) pairs.
*        After all explicit values are stored, FinishOneRow pushes the
*        implicit zeros for features that need them.
* \param tid thread index used by the underlying bin storage
* \param row_idx index of the row being pushed
* \param feature_values sparse (raw feature index, value) pairs of this row
*/
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) {
  if (is_finish_load_) { return; }
  std::vector<bool> is_feature_added(num_features_, false);
  for (const auto& entry : feature_values) {
    if (entry.first >= num_total_features_) { continue; }
    const int inner_idx = used_feature_map_[entry.first];
    if (inner_idx < 0) { continue; }  // feature not used by this dataset
    is_feature_added[inner_idx] = true;
    feature_groups_[feature2group_[inner_idx]]->PushData(tid, feature2subfeature_[inner_idx], row_idx, entry.second);
  }
  FinishOneRow(tid, row_idx, is_feature_added);
}
inline void PushOneData(int tid, data_size_t row_idx, int group, int sub_feature, double value) {
......@@ -647,6 +661,7 @@ class Dataset {
int min_data_in_bin_;
bool use_missing_;
bool zero_as_missing_;
std::vector<int> feature_need_push_zeros_;
};
} // namespace LightGBM
......
......@@ -34,14 +34,14 @@ class FeatureGroup {
std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
data_size_t num_data, double sparse_threshold, bool is_enable_sparse) : num_feature_(num_feature) {
CHECK(static_cast<int>(bin_mappers->size()) == num_feature);
// use bin at zero to store default_bin
// use bin at zero to store most_freq_bin
num_total_bin_ = 1;
bin_offsets_.emplace_back(num_total_bin_);
int cnt_non_zero = 0;
for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(bin_mappers->at(i).release());
auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) {
if (bin_mappers_[i]->GetMostFreqBin() == 0) {
num_bin -= 1;
}
num_total_bin_ += num_bin;
......@@ -57,13 +57,13 @@ class FeatureGroup {
std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
data_size_t num_data, bool is_sparse) : num_feature_(num_feature) {
CHECK(static_cast<int>(bin_mappers->size()) == num_feature);
// use bin at zero to store default_bin
// use bin at zero to store most_freq_bin
num_total_bin_ = 1;
bin_offsets_.emplace_back(num_total_bin_);
for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(bin_mappers->at(i).release());
auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) {
if (bin_mappers_[i]->GetMostFreqBin() == 0) {
num_bin -= 1;
}
num_total_bin_ += num_bin;
......@@ -99,7 +99,7 @@ class FeatureGroup {
for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(new BinMapper(memory_ptr));
auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) {
if (bin_mappers_[i]->GetMostFreqBin() == 0) {
num_bin -= 1;
}
num_total_bin_ += num_bin;
......@@ -130,9 +130,9 @@ class FeatureGroup {
*/
inline void PushData(int tid, int sub_feature_idx, data_size_t line_idx, double value) {
uint32_t bin = bin_mappers_[sub_feature_idx]->ValueToBin(value);
if (bin == bin_mappers_[sub_feature_idx]->GetDefaultBin()) { return; }
if (bin == bin_mappers_[sub_feature_idx]->GetMostFreqBin()) { return; }
bin += bin_offsets_[sub_feature_idx];
if (bin_mappers_[sub_feature_idx]->GetDefaultBin() == 0) {
if (bin_mappers_[sub_feature_idx]->GetMostFreqBin() == 0) {
bin -= 1;
}
bin_data_->Push(tid, line_idx, bin);
......@@ -145,8 +145,8 @@ class FeatureGroup {
/*!
* \brief Create an iterator over one sub-feature of this group.
* \param sub_feature index of the feature inside this group
* \return heap-allocated iterator over [lower, upper] bins of the sub-feature
*/
inline BinIterator* SubFeatureIterator(int sub_feature) {
  // bin range of this sub-feature inside the shared group bin space
  const uint32_t lower = bin_offsets_[sub_feature];
  const uint32_t upper = bin_offsets_[sub_feature + 1] - 1;
  return bin_data_->GetIterator(lower, upper, bin_mappers_[sub_feature]->GetMostFreqBin());
}
/*!
......@@ -157,8 +157,8 @@ class FeatureGroup {
/*!
* \brief Create an iterator over the entire bin range of this feature group.
* \return heap-allocated iterator covering all sub-features of the group
*/
inline BinIterator* FeatureGroupIterator() {
  const uint32_t lower = bin_offsets_[0];
  const uint32_t upper = bin_offsets_.back() - 1;
  // at group level there is no single most frequent bin, so pass 0
  return bin_data_->GetIterator(lower, upper, 0);
}
inline data_size_t Split(
......@@ -172,12 +172,13 @@ class FeatureGroup {
uint32_t min_bin = bin_offsets_[sub_feature];
uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1;
uint32_t default_bin = bin_mappers_[sub_feature]->GetDefaultBin();
uint32_t most_freq_bin = bin_mappers_[sub_feature]->GetMostFreqBin();
if (bin_mappers_[sub_feature]->bin_type() == BinType::NumericalBin) {
auto missing_type = bin_mappers_[sub_feature]->missing_type();
return bin_data_->Split(min_bin, max_bin, default_bin, missing_type, default_left,
return bin_data_->Split(min_bin, max_bin, default_bin, most_freq_bin, missing_type, default_left,
*threshold, data_indices, num_data, lte_indices, gt_indices);
} else {
return bin_data_->SplitCategorical(min_bin, max_bin, default_bin, threshold, num_threshold, data_indices, num_data, lte_indices, gt_indices);
return bin_data_->SplitCategorical(min_bin, max_bin, most_freq_bin, threshold, num_threshold, data_indices, num_data, lte_indices, gt_indices);
}
}
/*!
......
......@@ -889,13 +889,21 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int group = ret->Feature2Group(feature_idx);
int sub_feature = ret->Feture2SubFeature(feature_idx);
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
int row_idx = 0;
while (row_idx < nrow) {
auto pair = col_it.NextNonZero();
row_idx = pair.first;
// no more data
if (row_idx < 0) { break; }
ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
auto bin_mapper = ret->FeatureBinMapper(feature_idx);
if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
int row_idx = 0;
while (row_idx < nrow) {
auto pair = col_it.NextNonZero();
row_idx = pair.first;
// no more data
if (row_idx < 0) { break; }
ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
}
} else {
for (int row_idx = 0; row_idx < nrow; ++row_idx) {
auto val = col_it.Get(row_idx);
ret->PushOneData(tid, row_idx, group, sub_feature, val);
}
}
OMP_LOOP_EX_END();
}
......
......@@ -4,6 +4,7 @@
*/
#include <LightGBM/bin.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/file_io.h>
......@@ -40,6 +41,7 @@ namespace LightGBM {
min_val_ = other.min_val_;
max_val_ = other.max_val_;
default_bin_ = other.default_bin_;
most_freq_bin_ = other.most_freq_bin_;
}
BinMapper::BinMapper(const void* memory) {
......@@ -512,8 +514,15 @@ namespace LightGBM {
}
}
if (!is_trivial_) {
most_freq_bin_ = static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin));
// calculate sparse rate
sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / static_cast<double>(total_sample_cnt);
sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / total_sample_cnt;
const double max_sparse_rate = static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
if (most_freq_bin_ != default_bin_ && max_sparse_rate > 0.7f) {
sparse_rate_ = max_sparse_rate;
} else {
most_freq_bin_ = default_bin_;
}
} else {
sparse_rate_ = 1.0f;
}
......@@ -529,7 +538,7 @@ namespace LightGBM {
size += sizeof(BinType);
size += 2 * sizeof(double);
size += bin * sizeof(double);
size += sizeof(uint32_t);
size += sizeof(uint32_t) * 2;
return size;
}
......@@ -550,6 +559,8 @@ namespace LightGBM {
buffer += sizeof(max_val_);
std::memcpy(buffer, &default_bin_, sizeof(default_bin_));
buffer += sizeof(default_bin_);
std::memcpy(buffer, &most_freq_bin_, sizeof(most_freq_bin_));
buffer += sizeof(most_freq_bin_);
if (bin_type_ == BinType::NumericalBin) {
std::memcpy(buffer, bin_upper_bound_.data(), num_bin_ * sizeof(double));
} else {
......@@ -574,6 +585,8 @@ namespace LightGBM {
buffer += sizeof(max_val_);
std::memcpy(&default_bin_, buffer, sizeof(default_bin_));
buffer += sizeof(default_bin_);
std::memcpy(&most_freq_bin_, buffer, sizeof(most_freq_bin_));
buffer += sizeof(most_freq_bin_);
if (bin_type_ == BinType::NumericalBin) {
bin_upper_bound_ = std::vector<double>(num_bin_);
std::memcpy(bin_upper_bound_.data(), buffer, num_bin_ * sizeof(double));
......@@ -596,6 +609,7 @@ namespace LightGBM {
writer->Write(&min_val_, sizeof(min_val_));
writer->Write(&max_val_, sizeof(max_val_));
writer->Write(&default_bin_, sizeof(default_bin_));
writer->Write(&most_freq_bin_, sizeof(most_freq_bin_));
if (bin_type_ == BinType::NumericalBin) {
writer->Write(bin_upper_bound_.data(), sizeof(double) * num_bin_);
} else {
......@@ -605,7 +619,7 @@ namespace LightGBM {
size_t BinMapper::SizesInByte() const {
size_t ret = sizeof(num_bin_) + sizeof(missing_type_) + sizeof(is_trivial_) + sizeof(sparse_rate_)
+ sizeof(bin_type_) + sizeof(min_val_) + sizeof(max_val_) + sizeof(default_bin_);
+ sizeof(bin_type_) + sizeof(min_val_) + sizeof(max_val_) + sizeof(default_bin_) + sizeof(most_freq_bin_);
if (bin_type_ == BinType::NumericalBin) {
ret += sizeof(double) * num_bin_;
} else {
......
......@@ -67,6 +67,27 @@ void MarkUsed(std::vector<bool>* mark, const int* indices, int num_indices) {
}
}
/*!
* \brief Rebuild the list of "non-trivial" sample indices for a feature whose
*        most frequent bin differs from the default (zero) bin.
*        Unsampled positions (implicit zeros) are kept, and sampled values that
*        fall into the most frequent bin are dropped, so downstream grouping
*        treats the most frequent bin — not zero — as the trivial value.
* \param bin_mapper bin mapper of the feature
* \param num_total_samples total number of sampled rows
* \param num_indices length of sample_indices / sample_values
* \param sample_indices sorted row indices with explicit (non-zero) values
* \param sample_values values aligned with sample_indices
* \return adjusted index list; empty when no fix is needed
*/
std::vector<int> FixSampleIndices(const BinMapper* bin_mapper, int num_total_samples, int num_indices, const int* sample_indices, const double* sample_values) {
  std::vector<int> fixed;
  if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
    // zero already is the most frequent value; original indices stay valid
    return fixed;
  }
  int cur = 0;
  int pos = 0;
  while (cur < num_total_samples) {
    if (pos < num_indices && sample_indices[pos] < cur) {
      // stale entry already processed; advance the sparse cursor
      ++pos;
    } else if (pos < num_indices && sample_indices[pos] == cur) {
      // explicitly sampled: keep only when it is not the most frequent bin
      if (bin_mapper->ValueToBin(sample_values[pos]) != bin_mapper->GetMostFreqBin()) {
        fixed.push_back(cur);
      }
      ++cur;
    } else {
      // unsampled => implicit zero, which is no longer the trivial bin
      fixed.push_back(cur);
      ++cur;
    }
  }
  return fixed;
}
std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
const std::vector<int>& find_order,
int** sample_indices,
......@@ -147,6 +168,7 @@ std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMa
std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
int** sample_indices,
double** sample_values,
const int* num_per_col,
int num_sample_col,
size_t total_sample_cnt,
......@@ -187,8 +209,23 @@ std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_
for (auto sidx : sorted_idx) {
feature_order_by_cnt.push_back(used_features[sidx]);
}
auto features_in_group = FindGroups(bin_mappers, used_features, sample_indices, num_per_col, num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
auto group2 = FindGroups(bin_mappers, feature_order_by_cnt, sample_indices, num_per_col, num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
std::vector<std::vector<int>> tmp_indices;
std::vector<int> tmp_num_per_col(num_sample_col, 0);
for (auto fidx : used_features) {
if (fidx >= num_sample_col) {
continue;
}
auto ret = FixSampleIndices(bin_mappers[fidx].get(), static_cast<int>(total_sample_cnt), num_per_col[fidx], sample_indices[fidx], sample_values[fidx]);
if (!ret.empty()) {
tmp_indices.push_back(ret);
tmp_num_per_col[fidx] = static_cast<int>(ret.size());
sample_indices[fidx] = tmp_indices.back().data();
} else {
tmp_num_per_col[fidx] = num_per_col[fidx];
}
}
auto features_in_group = FindGroups(bin_mappers, used_features, sample_indices, tmp_num_per_col.data(), num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
auto group2 = FindGroups(bin_mappers, feature_order_by_cnt, sample_indices, tmp_num_per_col.data(), num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
if (features_in_group.size() > group2.size()) {
features_in_group = group2;
}
......@@ -230,6 +267,7 @@ void Dataset::Construct(
int num_total_features,
const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices,
double** sample_values,
const int* num_per_col,
int num_sample_col,
size_t total_sample_cnt,
......@@ -252,7 +290,7 @@ void Dataset::Construct(
if (io_config.enable_bundle && !used_features.empty()) {
features_in_group = FastFeatureBundling(*bin_mappers,
sample_non_zero_indices, num_per_col, num_sample_col, total_sample_cnt,
sample_non_zero_indices, sample_values, num_per_col, num_sample_col, total_sample_cnt,
used_features, io_config.max_conflict_rate,
num_data_, io_config.min_data_in_leaf,
sparse_threshold_, io_config.is_enable_sparse, io_config.device_type == std::string("gpu"));
......@@ -268,6 +306,7 @@ void Dataset::Construct(
real_feature_idx_.resize(num_features_);
feature2group_.resize(num_features_);
feature2subfeature_.resize(num_features_);
feature_need_push_zeros_.clear();
for (int i = 0; i < num_groups_; ++i) {
auto cur_features = features_in_group[i];
int cur_cnt_features = static_cast<int>(cur_features.size());
......@@ -280,6 +319,9 @@ void Dataset::Construct(
feature2group_[cur_fidx] = i;
feature2subfeature_[cur_fidx] = j;
cur_bin_mappers.emplace_back(ref_bin_mappers[real_fidx].release());
if (cur_bin_mappers.back()->GetDefaultBin() != cur_bin_mappers.back()->GetMostFreqBin()) {
feature_need_push_zeros_.push_back(cur_fidx);
}
++cur_fidx;
}
feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
......@@ -453,6 +495,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
monotone_types_ = dataset->monotone_types_;
feature_penalty_ = dataset->feature_penalty_;
forced_bin_bounds_ = dataset->forced_bin_bounds_;
feature_need_push_zeros_ = dataset->feature_need_push_zeros_;
}
void Dataset::CreateValid(const Dataset* dataset) {
......@@ -464,9 +507,13 @@ void Dataset::CreateValid(const Dataset* dataset) {
feature2group_.clear();
feature2subfeature_.clear();
// copy feature bin mapper data
feature_need_push_zeros_.clear();
for (int i = 0; i < num_features_; ++i) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers;
bin_mappers.emplace_back(new BinMapper(*(dataset->FeatureBinMapper(i))));
if (bin_mappers.back()->GetDefaultBin() != bin_mappers.back()->GetMostFreqBin()) {
feature_need_push_zeros_.push_back(i);
}
feature_groups_.emplace_back(new FeatureGroup(
1,
&bin_mappers,
......@@ -812,7 +859,7 @@ void Dataset::DumpTextFile(const char* text_filename) {
if (inner_feature_idx < 0) {
fprintf(file, "NA, ");
} else {
fprintf(file, "%d, ", iterators[inner_feature_idx]->RawGet(i));
fprintf(file, "%d, ", iterators[inner_feature_idx]->Get(i));
}
}
}
......@@ -999,17 +1046,17 @@ void Dataset::FixHistogram(int feature_idx, double sum_gradient, double sum_hess
const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx];
const BinMapper* bin_mapper = feature_groups_[group]->bin_mappers_[sub_feature].get();
const int default_bin = bin_mapper->GetDefaultBin();
if (default_bin > 0) {
const int most_freq_bin = bin_mapper->GetMostFreqBin();
if (most_freq_bin > 0) {
const int num_bin = bin_mapper->num_bin();
data[default_bin].sum_gradients = sum_gradient;
data[default_bin].sum_hessians = sum_hessian;
data[default_bin].cnt = num_data;
data[most_freq_bin].sum_gradients = sum_gradient;
data[most_freq_bin].sum_hessians = sum_hessian;
data[most_freq_bin].cnt = num_data;
for (int i = 0; i < num_bin; ++i) {
if (i != default_bin) {
data[default_bin].sum_gradients -= data[i].sum_gradients;
data[default_bin].sum_hessians -= data[i].sum_hessians;
data[default_bin].cnt -= data[i].cnt;
if (i != most_freq_bin) {
data[most_freq_bin].sum_gradients -= data[i].sum_gradients;
data[most_freq_bin].sum_hessians -= data[i].sum_hessians;
data[most_freq_bin].cnt -= data[i].cnt;
}
}
}
......
......@@ -717,7 +717,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
}
}
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, num_per_col, num_col, total_sample_size, config_);
dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
dataset->set_feature_names(feature_names_);
return dataset.release();
}
......@@ -1040,8 +1040,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
cp_ptr += bin_mappers[i]->SizesInByte();
}
}
sample_values.clear();
dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Common::Vector2Ptr<double>(&sample_values).data(),
Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
}
......@@ -1066,11 +1066,13 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
ref_text_data[i].clear();
// shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
// text_reader_->Lines()[i].shrink_to_fit();
std::vector<bool> is_feature_added(dataset->num_features_, false);
// push data
for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; }
int feature_idx = dataset->used_feature_map_[inner_data.first];
if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
// if is used feature
int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx];
......@@ -1083,6 +1085,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
}
}
}
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......@@ -1110,10 +1113,12 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
// shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
// text_reader_->Lines()[i].shrink_to_fit();
// push data
std::vector<bool> is_feature_added(dataset->num_features_, false);
for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; }
int feature_idx = dataset->used_feature_map_[inner_data.first];
if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
// if is used feature
int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx];
......@@ -1126,6 +1131,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
}
}
}
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......@@ -1167,11 +1173,13 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
}
// set label
dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
std::vector<bool> is_feature_added(dataset->num_features_, false);
// push data
for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; }
int feature_idx = dataset->used_feature_map_[inner_data.first];
if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
// if is used feature
int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx];
......@@ -1184,6 +1192,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
}
}
}
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......
......@@ -19,11 +19,11 @@ class DenseBin;
template <typename VAL_T>
class DenseBinIterator: public BinIterator {
public:
explicit DenseBinIterator(const DenseBin<VAL_T>* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin)
explicit DenseBinIterator(const DenseBin<VAL_T>* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
: bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
max_bin_(static_cast<VAL_T>(max_bin)),
default_bin_(static_cast<VAL_T>(default_bin)) {
if (default_bin_ == 0) {
most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
if (most_freq_bin_ == 0) {
offset_ = 1;
} else {
offset_ = 0;
......@@ -37,7 +37,7 @@ class DenseBinIterator: public BinIterator {
const DenseBin<VAL_T>* bin_data_;
VAL_T min_bin_;
VAL_T max_bin_;
VAL_T default_bin_;
VAL_T most_freq_bin_;
uint8_t offset_;
};
/*!
......@@ -66,7 +66,7 @@ class DenseBin: public Bin {
}
}
BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override;
BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
const score_t* ordered_gradients, const score_t* ordered_hessians,
......@@ -128,7 +128,7 @@ class DenseBin: public Bin {
}
data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left,
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
......@@ -136,21 +136,23 @@ class DenseBin: public Bin {
const VAL_T minb = static_cast<VAL_T>(min_bin);
const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
if (default_bin == 0) {
VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1;
t_default_bin -= 1;
t_most_freq_bin -= 1;
}
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (most_freq_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (missing_type == MissingType::NaN) {
if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (default_left) {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
......@@ -158,10 +160,10 @@ class DenseBin: public Bin {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin == maxb) {
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
......@@ -169,19 +171,36 @@ class DenseBin: public Bin {
}
}
} else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_indices = lte_indices;
default_count = &lte_count;
if ((default_left && missing_type == MissingType::Zero)
|| (default_bin <= threshold && missing_type != MissingType::Zero)) {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
if (default_bin == most_freq_bin) {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin == t_default_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
}
}
......@@ -189,7 +208,7 @@ class DenseBin: public Bin {
}
data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
......@@ -197,7 +216,7 @@ class DenseBin: public Bin {
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, default_bin)) {
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
......@@ -271,7 +290,7 @@ uint32_t DenseBinIterator<VAL_T>::Get(data_size_t idx) {
if (ret >= min_bin_ && ret <= max_bin_) {
return ret - min_bin_ + offset_;
} else {
return default_bin_;
return most_freq_bin_;
}
}
......@@ -281,8 +300,8 @@ inline uint32_t DenseBinIterator<VAL_T>::RawGet(data_size_t idx) {
}
template <typename VAL_T>
BinIterator* DenseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const {
return new DenseBinIterator<VAL_T>(this, min_bin, max_bin, default_bin);
BinIterator* DenseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
return new DenseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
}
} // namespace LightGBM
......
......@@ -17,11 +17,11 @@ class Dense4bitsBin;
class Dense4bitsBinIterator : public BinIterator {
public:
explicit Dense4bitsBinIterator(const Dense4bitsBin* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin)
explicit Dense4bitsBinIterator(const Dense4bitsBin* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
: bin_data_(bin_data), min_bin_(static_cast<uint8_t>(min_bin)),
max_bin_(static_cast<uint8_t>(max_bin)),
default_bin_(static_cast<uint8_t>(default_bin)) {
if (default_bin_ == 0) {
most_freq_bin_(static_cast<uint8_t>(most_freq_bin)) {
if (most_freq_bin_ == 0) {
offset_ = 1;
} else {
offset_ = 0;
......@@ -35,7 +35,7 @@ class Dense4bitsBinIterator : public BinIterator {
const Dense4bitsBin* bin_data_;
uint8_t min_bin_;
uint8_t max_bin_;
uint8_t default_bin_;
uint8_t most_freq_bin_;
uint8_t offset_;
};
......@@ -71,7 +71,7 @@ class Dense4bitsBin : public Bin {
}
}
inline BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override;
inline BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
const score_t* ordered_gradients, const score_t* ordered_hessians,
......@@ -134,7 +134,7 @@ class Dense4bitsBin : public Bin {
}
data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left,
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
......@@ -142,21 +142,23 @@ class Dense4bitsBin : public Bin {
const uint8_t minb = static_cast<uint8_t>(min_bin);
const uint8_t maxb = static_cast<uint8_t>(max_bin);
uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin);
if (default_bin == 0) {
uint8_t t_most_freq_bin = static_cast<uint8_t>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1;
t_default_bin -= 1;
t_most_freq_bin -= 1;
}
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (most_freq_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (missing_type == MissingType::NaN) {
if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (default_left) {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
......@@ -164,10 +166,10 @@ class Dense4bitsBin : public Bin {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin == maxb) {
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
......@@ -175,19 +177,36 @@ class Dense4bitsBin : public Bin {
}
}
} else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_indices = lte_indices;
default_count = &lte_count;
if ((default_left && missing_type == MissingType::Zero)
|| (default_bin <= threshold && missing_type != MissingType::Zero)) {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
if (default_bin == most_freq_bin) {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin == t_default_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
}
}
......@@ -195,7 +214,7 @@ class Dense4bitsBin : public Bin {
}
data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
......@@ -203,7 +222,7 @@ class Dense4bitsBin : public Bin {
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, default_bin)) {
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
......@@ -303,7 +322,7 @@ uint32_t Dense4bitsBinIterator::Get(data_size_t idx) {
if (bin >= min_bin_ && bin <= max_bin_) {
return bin - min_bin_ + offset_;
} else {
return default_bin_;
return most_freq_bin_;
}
}
......@@ -311,8 +330,8 @@ uint32_t Dense4bitsBinIterator::RawGet(data_size_t idx) {
return (bin_data_->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
}
inline BinIterator* Dense4bitsBin::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const {
return new Dense4bitsBinIterator(this, min_bin, max_bin, default_bin);
inline BinIterator* Dense4bitsBin::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
return new Dense4bitsBinIterator(this, min_bin, max_bin, most_freq_bin);
}
} // namespace LightGBM
......
......@@ -26,11 +26,11 @@ template <typename VAL_T>
class SparseBinIterator: public BinIterator {
public:
SparseBinIterator(const SparseBin<VAL_T>* bin_data,
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin)
uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
: bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
max_bin_(static_cast<VAL_T>(max_bin)),
default_bin_(static_cast<VAL_T>(default_bin)) {
if (default_bin_ == 0) {
most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
if (most_freq_bin_ == 0) {
offset_ = 1;
} else {
offset_ = 0;
......@@ -50,7 +50,7 @@ class SparseBinIterator: public BinIterator {
if (ret >= min_bin_ && ret <= max_bin_) {
return ret - min_bin_ + offset_;
} else {
return default_bin_;
return most_freq_bin_;
}
}
......@@ -62,7 +62,7 @@ class SparseBinIterator: public BinIterator {
data_size_t i_delta_;
VAL_T min_bin_;
VAL_T max_bin_;
VAL_T default_bin_;
VAL_T most_freq_bin_;
uint8_t offset_;
};
......@@ -100,7 +100,7 @@ class SparseBin: public Bin {
}
}
BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override;
BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
void ConstructHistogram(const data_size_t*, data_size_t, data_size_t, const score_t*,
const score_t*, HistogramBinEntry*) const override {
......@@ -145,32 +145,34 @@ class SparseBin: public Bin {
}
}
data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left,
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
// not need to split
if (num_data <= 0) { return 0; }
VAL_T th = static_cast<VAL_T>(threshold + min_bin);
const VAL_T minb = static_cast<VAL_T>(min_bin);
const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
if (default_bin == 0) {
VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1;
t_default_bin -= 1;
t_most_freq_bin -= 1;
}
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
if (most_freq_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (missing_type == MissingType::NaN) {
if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (default_left) {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
......@@ -178,10 +180,10 @@ class SparseBin: public Bin {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin == maxb) {
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
......@@ -189,19 +191,36 @@ class SparseBin: public Bin {
}
}
} else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_indices = lte_indices;
default_count = &lte_count;
if ((default_left && missing_type == MissingType::Zero)
|| (default_bin <= threshold && missing_type != MissingType::Zero)) {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
if (default_bin == most_freq_bin) {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin == t_default_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
}
}
......@@ -209,7 +228,7 @@ class SparseBin: public Bin {
}
data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
......@@ -218,7 +237,7 @@ class SparseBin: public Bin {
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, default_bin)) {
if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
......@@ -464,8 +483,8 @@ inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
}
template <typename VAL_T>
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const {
return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, default_bin);
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
}
} // namespace LightGBM
......
......@@ -61,7 +61,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
feature_distribution[cur_min_machine].push_back(inner_feature_index);
auto num_bin = this->train_data_->FeatureNumBin(inner_feature_index);
if (this->train_data_->FeatureBinMapper(inner_feature_index)->GetDefaultBin() == 0) {
if (this->train_data_->FeatureBinMapper(inner_feature_index)->GetMostFreqBin() == 0) {
num_bin -= 1;
}
num_bins_distributed[cur_min_machine] += num_bin;
......@@ -79,7 +79,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
block_len_[i] = 0;
for (auto fid : feature_distribution[i]) {
auto num_bin = this->train_data_->FeatureNumBin(fid);
if (this->train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
num_bin -= 1;
}
block_len_[i] += num_bin * sizeof(HistogramBinEntry);
......@@ -98,7 +98,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
for (auto fid : feature_distribution[i]) {
buffer_write_start_pos_[fid] = bin_size;
auto num_bin = this->train_data_->FeatureNumBin(fid);
if (this->train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
num_bin -= 1;
}
bin_size += num_bin * sizeof(HistogramBinEntry);
......@@ -110,7 +110,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
for (auto fid : feature_distribution[rank_]) {
buffer_read_start_pos_[fid] = bin_size;
auto num_bin = this->train_data_->FeatureNumBin(fid);
if (this->train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
num_bin -= 1;
}
bin_size += num_bin * sizeof(HistogramBinEntry);
......
......@@ -710,7 +710,7 @@ class HistogramPool {
feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type();
feature_metas_[i].monotone_type = train_data->FeatureMonotone(i);
feature_metas_[i].penalty = train_data->FeaturePenalte(i);
if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) {
if (train_data->FeatureBinMapper(i)->GetMostFreqBin() == 0) {
feature_metas_[i].offset = 1;
} else {
feature_metas_[i].offset = 0;
......@@ -740,7 +740,7 @@ class HistogramPool {
offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j));
pool_[i][j].Init(data_[i].data() + offset, &feature_metas_[j]);
auto num_bin = train_data->FeatureNumBin(j);
if (train_data->FeatureBinMapper(j)->GetDefaultBin() == 0) {
if (train_data->FeatureBinMapper(j)->GetMostFreqBin() == 0) {
num_bin -= 1;
}
offset += static_cast<uint64_t>(num_bin);
......
......@@ -74,7 +74,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type();
feature_metas_[i].monotone_type = train_data->FeatureMonotone(i);
feature_metas_[i].penalty = train_data->FeaturePenalte(i);
if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) {
if (train_data->FeatureBinMapper(i)->GetMostFreqBin() == 0) {
feature_metas_[i].offset = 1;
} else {
feature_metas_[i].offset = 0;
......@@ -88,7 +88,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
smaller_leaf_histogram_array_global_[j].Init(smaller_leaf_histogram_data_.data() + offset, &feature_metas_[j]);
larger_leaf_histogram_array_global_[j].Init(larger_leaf_histogram_data_.data() + offset, &feature_metas_[j]);
auto num_bin = train_data->FeatureNumBin(j);
if (train_data->FeatureBinMapper(j)->GetDefaultBin() == 0) {
if (train_data->FeatureBinMapper(j)->GetMostFreqBin() == 0) {
num_bin -= 1;
}
offset += static_cast<uint64_t>(num_bin);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment