"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "9f1af051b44564eaab2bebe1612c6a52217bb32b"
Unverified Commit c7e90393 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support most frequent bin (#2689)

* implement

* fix warning

* fix bug

* fix a bug

* remove unneed function

* fix data push bug

* fix valid data push

* fix bug for missing_type=zero

* refine split

* renames

* typo
parent 82886ba6
...@@ -117,6 +117,7 @@ class BinMapper { ...@@ -117,6 +117,7 @@ class BinMapper {
return bin_2_categorical_[bin]; return bin_2_categorical_[bin];
} }
} }
/*! /*!
* \brief Get sizes in byte of this object * \brief Get sizes in byte of this object
*/ */
...@@ -135,6 +136,11 @@ class BinMapper { ...@@ -135,6 +136,11 @@ class BinMapper {
inline uint32_t GetDefaultBin() const { inline uint32_t GetDefaultBin() const {
return default_bin_; return default_bin_;
} }
inline uint32_t GetMostFreqBin() const {
return most_freq_bin_;
}
/*! /*!
* \brief Construct feature value to bin mapper according feature values * \brief Construct feature value to bin mapper according feature values
* \param values (Sampled) values of this feature, Note: not include zero. * \param values (Sampled) values of this feature, Note: not include zero.
...@@ -211,6 +217,8 @@ class BinMapper { ...@@ -211,6 +217,8 @@ class BinMapper {
double max_val_; double max_val_;
/*! \brief bin value of feature value 0 */ /*! \brief bin value of feature value 0 */
uint32_t default_bin_; uint32_t default_bin_;
uint32_t most_freq_bin_;
}; };
/*! /*!
...@@ -306,10 +314,10 @@ class Bin { ...@@ -306,10 +314,10 @@ class Bin {
* \brief Get bin iterator of this bin for specific feature * \brief Get bin iterator of this bin for specific feature
* \param min_bin min_bin of current used feature * \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature * \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin] * \param most_freq_bin
* \return Iterator of this bin * \return Iterator of this bin
*/ */
virtual BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const = 0; virtual BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const = 0;
/*! /*!
* \brief Save binary data to file * \brief Save binary data to file
...@@ -381,7 +389,8 @@ class Bin { ...@@ -381,7 +389,8 @@ class Bin {
* \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices) * \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices)
* \param min_bin min_bin of current used feature * \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature * \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin] * \param default_bin default bin for feature value 0
* \param most_freq_bin
* \param missing_type missing type * \param missing_type missing type
* \param default_left missing bin will go to left child * \param default_left missing bin will go to left child
* \param threshold The split threshold. * \param threshold The split threshold.
...@@ -392,7 +401,7 @@ class Bin { ...@@ -392,7 +401,7 @@ class Bin {
* \return The number of less than or equal data. * \return The number of less than or equal data.
*/ */
virtual data_size_t Split(uint32_t min_bin, uint32_t max_bin, virtual data_size_t Split(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t threshold, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left, uint32_t threshold,
data_size_t* data_indices, data_size_t num_data, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const = 0; data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
...@@ -400,7 +409,7 @@ class Bin { ...@@ -400,7 +409,7 @@ class Bin {
* \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices) * \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices)
* \param min_bin min_bin of current used feature * \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature * \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin] * \param most_freq_bin
* \param threshold The split threshold. * \param threshold The split threshold.
* \param num_threshold Number of threshold * \param num_threshold Number of threshold
* \param data_indices Used data indices. After called this function. The less than or equal data indices will store on this object. * \param data_indices Used data indices. After called this function. The less than or equal data indices will store on this object.
...@@ -410,7 +419,7 @@ class Bin { ...@@ -410,7 +419,7 @@ class Bin {
* \return The number of less than or equal data. * \return The number of less than or equal data.
*/ */
virtual data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin, virtual data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, const uint32_t* threshold, int num_threshold, uint32_t most_freq_bin, const uint32_t* threshold, int num_threshold,
data_size_t* data_indices, data_size_t num_data, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const = 0; data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
...@@ -433,7 +442,6 @@ class Bin { ...@@ -433,7 +442,6 @@ class Bin {
* \param is_enable_sparse True if enable sparse feature * \param is_enable_sparse True if enable sparse feature
* \param sparse_threshold Threshold for treating a feature as a sparse feature * \param sparse_threshold Threshold for treating a feature as a sparse feature
* \param is_sparse Will set to true if this bin is sparse * \param is_sparse Will set to true if this bin is sparse
* \param default_bin Default bin for zeros value
* \return The bin data object * \return The bin data object
*/ */
static Bin* CreateBin(data_size_t num_data, int num_bin, static Bin* CreateBin(data_size_t num_data, int num_bin,
......
...@@ -293,6 +293,7 @@ class Dataset { ...@@ -293,6 +293,7 @@ class Dataset {
int num_total_features, int num_total_features,
const std::vector<std::vector<double>>& forced_bins, const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices, int** sample_non_zero_indices,
double** sample_values,
const int* num_per_col, const int* num_per_col,
int num_sample_col, int num_sample_col,
size_t total_sample_cnt, size_t total_sample_cnt,
...@@ -319,6 +320,16 @@ class Dataset { ...@@ -319,6 +320,16 @@ class Dataset {
return true; return true;
} }
inline void FinishOneRow(int tid, data_size_t row_idx, const std::vector<bool>& is_feature_added) {
if (is_finish_load_) { return; }
for (auto fidx : feature_need_push_zeros_) {
if (is_feature_added[fidx]) { continue; }
const int group = feature2group_[fidx];
const int sub_feature = feature2subfeature_[fidx];
feature_groups_[group]->PushData(tid, sub_feature, row_idx, 0.0f);
}
}
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) { inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
if (is_finish_load_) { return; } if (is_finish_load_) { return; }
for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) { for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
...@@ -333,15 +344,18 @@ class Dataset { ...@@ -333,15 +344,18 @@ class Dataset {
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) { inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) {
if (is_finish_load_) { return; } if (is_finish_load_) { return; }
std::vector<bool> is_feature_added(num_features_, false);
for (auto& inner_data : feature_values) { for (auto& inner_data : feature_values) {
if (inner_data.first >= num_total_features_) { continue; } if (inner_data.first >= num_total_features_) { continue; }
int feature_idx = used_feature_map_[inner_data.first]; int feature_idx = used_feature_map_[inner_data.first];
if (feature_idx >= 0) { if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
const int group = feature2group_[feature_idx]; const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx]; const int sub_feature = feature2subfeature_[feature_idx];
feature_groups_[group]->PushData(tid, sub_feature, row_idx, inner_data.second); feature_groups_[group]->PushData(tid, sub_feature, row_idx, inner_data.second);
} }
} }
FinishOneRow(tid, row_idx, is_feature_added);
} }
inline void PushOneData(int tid, data_size_t row_idx, int group, int sub_feature, double value) { inline void PushOneData(int tid, data_size_t row_idx, int group, int sub_feature, double value) {
...@@ -647,6 +661,7 @@ class Dataset { ...@@ -647,6 +661,7 @@ class Dataset {
int min_data_in_bin_; int min_data_in_bin_;
bool use_missing_; bool use_missing_;
bool zero_as_missing_; bool zero_as_missing_;
std::vector<int> feature_need_push_zeros_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -34,14 +34,14 @@ class FeatureGroup { ...@@ -34,14 +34,14 @@ class FeatureGroup {
std::vector<std::unique_ptr<BinMapper>>* bin_mappers, std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
data_size_t num_data, double sparse_threshold, bool is_enable_sparse) : num_feature_(num_feature) { data_size_t num_data, double sparse_threshold, bool is_enable_sparse) : num_feature_(num_feature) {
CHECK(static_cast<int>(bin_mappers->size()) == num_feature); CHECK(static_cast<int>(bin_mappers->size()) == num_feature);
// use bin at zero to store default_bin // use bin at zero to store most_freq_bin
num_total_bin_ = 1; num_total_bin_ = 1;
bin_offsets_.emplace_back(num_total_bin_); bin_offsets_.emplace_back(num_total_bin_);
int cnt_non_zero = 0; int cnt_non_zero = 0;
for (int i = 0; i < num_feature_; ++i) { for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(bin_mappers->at(i).release()); bin_mappers_.emplace_back(bin_mappers->at(i).release());
auto num_bin = bin_mappers_[i]->num_bin(); auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) { if (bin_mappers_[i]->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
num_total_bin_ += num_bin; num_total_bin_ += num_bin;
...@@ -57,13 +57,13 @@ class FeatureGroup { ...@@ -57,13 +57,13 @@ class FeatureGroup {
std::vector<std::unique_ptr<BinMapper>>* bin_mappers, std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
data_size_t num_data, bool is_sparse) : num_feature_(num_feature) { data_size_t num_data, bool is_sparse) : num_feature_(num_feature) {
CHECK(static_cast<int>(bin_mappers->size()) == num_feature); CHECK(static_cast<int>(bin_mappers->size()) == num_feature);
// use bin at zero to store default_bin // use bin at zero to store most_freq_bin
num_total_bin_ = 1; num_total_bin_ = 1;
bin_offsets_.emplace_back(num_total_bin_); bin_offsets_.emplace_back(num_total_bin_);
for (int i = 0; i < num_feature_; ++i) { for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(bin_mappers->at(i).release()); bin_mappers_.emplace_back(bin_mappers->at(i).release());
auto num_bin = bin_mappers_[i]->num_bin(); auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) { if (bin_mappers_[i]->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
num_total_bin_ += num_bin; num_total_bin_ += num_bin;
...@@ -99,7 +99,7 @@ class FeatureGroup { ...@@ -99,7 +99,7 @@ class FeatureGroup {
for (int i = 0; i < num_feature_; ++i) { for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(new BinMapper(memory_ptr)); bin_mappers_.emplace_back(new BinMapper(memory_ptr));
auto num_bin = bin_mappers_[i]->num_bin(); auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) { if (bin_mappers_[i]->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
num_total_bin_ += num_bin; num_total_bin_ += num_bin;
...@@ -130,9 +130,9 @@ class FeatureGroup { ...@@ -130,9 +130,9 @@ class FeatureGroup {
*/ */
inline void PushData(int tid, int sub_feature_idx, data_size_t line_idx, double value) { inline void PushData(int tid, int sub_feature_idx, data_size_t line_idx, double value) {
uint32_t bin = bin_mappers_[sub_feature_idx]->ValueToBin(value); uint32_t bin = bin_mappers_[sub_feature_idx]->ValueToBin(value);
if (bin == bin_mappers_[sub_feature_idx]->GetDefaultBin()) { return; } if (bin == bin_mappers_[sub_feature_idx]->GetMostFreqBin()) { return; }
bin += bin_offsets_[sub_feature_idx]; bin += bin_offsets_[sub_feature_idx];
if (bin_mappers_[sub_feature_idx]->GetDefaultBin() == 0) { if (bin_mappers_[sub_feature_idx]->GetMostFreqBin() == 0) {
bin -= 1; bin -= 1;
} }
bin_data_->Push(tid, line_idx, bin); bin_data_->Push(tid, line_idx, bin);
...@@ -145,8 +145,8 @@ class FeatureGroup { ...@@ -145,8 +145,8 @@ class FeatureGroup {
inline BinIterator* SubFeatureIterator(int sub_feature) { inline BinIterator* SubFeatureIterator(int sub_feature) {
uint32_t min_bin = bin_offsets_[sub_feature]; uint32_t min_bin = bin_offsets_[sub_feature];
uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1; uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1;
uint32_t default_bin = bin_mappers_[sub_feature]->GetDefaultBin(); uint32_t most_freq_bin = bin_mappers_[sub_feature]->GetMostFreqBin();
return bin_data_->GetIterator(min_bin, max_bin, default_bin); return bin_data_->GetIterator(min_bin, max_bin, most_freq_bin);
} }
/*! /*!
...@@ -157,8 +157,8 @@ class FeatureGroup { ...@@ -157,8 +157,8 @@ class FeatureGroup {
inline BinIterator* FeatureGroupIterator() { inline BinIterator* FeatureGroupIterator() {
uint32_t min_bin = bin_offsets_[0]; uint32_t min_bin = bin_offsets_[0];
uint32_t max_bin = bin_offsets_.back() - 1; uint32_t max_bin = bin_offsets_.back() - 1;
uint32_t default_bin = 0; uint32_t most_freq_bin = 0;
return bin_data_->GetIterator(min_bin, max_bin, default_bin); return bin_data_->GetIterator(min_bin, max_bin, most_freq_bin);
} }
inline data_size_t Split( inline data_size_t Split(
...@@ -172,12 +172,13 @@ class FeatureGroup { ...@@ -172,12 +172,13 @@ class FeatureGroup {
uint32_t min_bin = bin_offsets_[sub_feature]; uint32_t min_bin = bin_offsets_[sub_feature];
uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1; uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1;
uint32_t default_bin = bin_mappers_[sub_feature]->GetDefaultBin(); uint32_t default_bin = bin_mappers_[sub_feature]->GetDefaultBin();
uint32_t most_freq_bin = bin_mappers_[sub_feature]->GetMostFreqBin();
if (bin_mappers_[sub_feature]->bin_type() == BinType::NumericalBin) { if (bin_mappers_[sub_feature]->bin_type() == BinType::NumericalBin) {
auto missing_type = bin_mappers_[sub_feature]->missing_type(); auto missing_type = bin_mappers_[sub_feature]->missing_type();
return bin_data_->Split(min_bin, max_bin, default_bin, missing_type, default_left, return bin_data_->Split(min_bin, max_bin, default_bin, most_freq_bin, missing_type, default_left,
*threshold, data_indices, num_data, lte_indices, gt_indices); *threshold, data_indices, num_data, lte_indices, gt_indices);
} else { } else {
return bin_data_->SplitCategorical(min_bin, max_bin, default_bin, threshold, num_threshold, data_indices, num_data, lte_indices, gt_indices); return bin_data_->SplitCategorical(min_bin, max_bin, most_freq_bin, threshold, num_threshold, data_indices, num_data, lte_indices, gt_indices);
} }
} }
/*! /*!
......
...@@ -889,13 +889,21 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -889,13 +889,21 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int group = ret->Feature2Group(feature_idx); int group = ret->Feature2Group(feature_idx);
int sub_feature = ret->Feture2SubFeature(feature_idx); int sub_feature = ret->Feture2SubFeature(feature_idx);
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i); CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
int row_idx = 0; auto bin_mapper = ret->FeatureBinMapper(feature_idx);
while (row_idx < nrow) { if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
auto pair = col_it.NextNonZero(); int row_idx = 0;
row_idx = pair.first; while (row_idx < nrow) {
// no more data auto pair = col_it.NextNonZero();
if (row_idx < 0) { break; } row_idx = pair.first;
ret->PushOneData(tid, row_idx, group, sub_feature, pair.second); // no more data
if (row_idx < 0) { break; }
ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
}
} else {
for (int row_idx = 0; row_idx < nrow; ++row_idx) {
auto val = col_it.Get(row_idx);
ret->PushOneData(tid, row_idx, group, sub_feature, val);
}
} }
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
#include <LightGBM/utils/file_io.h> #include <LightGBM/utils/file_io.h>
...@@ -40,6 +41,7 @@ namespace LightGBM { ...@@ -40,6 +41,7 @@ namespace LightGBM {
min_val_ = other.min_val_; min_val_ = other.min_val_;
max_val_ = other.max_val_; max_val_ = other.max_val_;
default_bin_ = other.default_bin_; default_bin_ = other.default_bin_;
most_freq_bin_ = other.most_freq_bin_;
} }
BinMapper::BinMapper(const void* memory) { BinMapper::BinMapper(const void* memory) {
...@@ -512,8 +514,15 @@ namespace LightGBM { ...@@ -512,8 +514,15 @@ namespace LightGBM {
} }
} }
if (!is_trivial_) { if (!is_trivial_) {
most_freq_bin_ = static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin));
// calculate sparse rate // calculate sparse rate
sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / static_cast<double>(total_sample_cnt); sparse_rate_ = static_cast<double>(cnt_in_bin[default_bin_]) / total_sample_cnt;
const double max_sparse_rate = static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
if (most_freq_bin_ != default_bin_ && max_sparse_rate > 0.7f) {
sparse_rate_ = max_sparse_rate;
} else {
most_freq_bin_ = default_bin_;
}
} else { } else {
sparse_rate_ = 1.0f; sparse_rate_ = 1.0f;
} }
...@@ -529,7 +538,7 @@ namespace LightGBM { ...@@ -529,7 +538,7 @@ namespace LightGBM {
size += sizeof(BinType); size += sizeof(BinType);
size += 2 * sizeof(double); size += 2 * sizeof(double);
size += bin * sizeof(double); size += bin * sizeof(double);
size += sizeof(uint32_t); size += sizeof(uint32_t) * 2;
return size; return size;
} }
...@@ -550,6 +559,8 @@ namespace LightGBM { ...@@ -550,6 +559,8 @@ namespace LightGBM {
buffer += sizeof(max_val_); buffer += sizeof(max_val_);
std::memcpy(buffer, &default_bin_, sizeof(default_bin_)); std::memcpy(buffer, &default_bin_, sizeof(default_bin_));
buffer += sizeof(default_bin_); buffer += sizeof(default_bin_);
std::memcpy(buffer, &most_freq_bin_, sizeof(most_freq_bin_));
buffer += sizeof(most_freq_bin_);
if (bin_type_ == BinType::NumericalBin) { if (bin_type_ == BinType::NumericalBin) {
std::memcpy(buffer, bin_upper_bound_.data(), num_bin_ * sizeof(double)); std::memcpy(buffer, bin_upper_bound_.data(), num_bin_ * sizeof(double));
} else { } else {
...@@ -574,6 +585,8 @@ namespace LightGBM { ...@@ -574,6 +585,8 @@ namespace LightGBM {
buffer += sizeof(max_val_); buffer += sizeof(max_val_);
std::memcpy(&default_bin_, buffer, sizeof(default_bin_)); std::memcpy(&default_bin_, buffer, sizeof(default_bin_));
buffer += sizeof(default_bin_); buffer += sizeof(default_bin_);
std::memcpy(&most_freq_bin_, buffer, sizeof(most_freq_bin_));
buffer += sizeof(most_freq_bin_);
if (bin_type_ == BinType::NumericalBin) { if (bin_type_ == BinType::NumericalBin) {
bin_upper_bound_ = std::vector<double>(num_bin_); bin_upper_bound_ = std::vector<double>(num_bin_);
std::memcpy(bin_upper_bound_.data(), buffer, num_bin_ * sizeof(double)); std::memcpy(bin_upper_bound_.data(), buffer, num_bin_ * sizeof(double));
...@@ -596,6 +609,7 @@ namespace LightGBM { ...@@ -596,6 +609,7 @@ namespace LightGBM {
writer->Write(&min_val_, sizeof(min_val_)); writer->Write(&min_val_, sizeof(min_val_));
writer->Write(&max_val_, sizeof(max_val_)); writer->Write(&max_val_, sizeof(max_val_));
writer->Write(&default_bin_, sizeof(default_bin_)); writer->Write(&default_bin_, sizeof(default_bin_));
writer->Write(&most_freq_bin_, sizeof(most_freq_bin_));
if (bin_type_ == BinType::NumericalBin) { if (bin_type_ == BinType::NumericalBin) {
writer->Write(bin_upper_bound_.data(), sizeof(double) * num_bin_); writer->Write(bin_upper_bound_.data(), sizeof(double) * num_bin_);
} else { } else {
...@@ -605,7 +619,7 @@ namespace LightGBM { ...@@ -605,7 +619,7 @@ namespace LightGBM {
size_t BinMapper::SizesInByte() const { size_t BinMapper::SizesInByte() const {
size_t ret = sizeof(num_bin_) + sizeof(missing_type_) + sizeof(is_trivial_) + sizeof(sparse_rate_) size_t ret = sizeof(num_bin_) + sizeof(missing_type_) + sizeof(is_trivial_) + sizeof(sparse_rate_)
+ sizeof(bin_type_) + sizeof(min_val_) + sizeof(max_val_) + sizeof(default_bin_); + sizeof(bin_type_) + sizeof(min_val_) + sizeof(max_val_) + sizeof(default_bin_) + sizeof(most_freq_bin_);
if (bin_type_ == BinType::NumericalBin) { if (bin_type_ == BinType::NumericalBin) {
ret += sizeof(double) * num_bin_; ret += sizeof(double) * num_bin_;
} else { } else {
......
...@@ -67,6 +67,27 @@ void MarkUsed(std::vector<bool>* mark, const int* indices, int num_indices) { ...@@ -67,6 +67,27 @@ void MarkUsed(std::vector<bool>* mark, const int* indices, int num_indices) {
} }
} }
std::vector<int> FixSampleIndices(const BinMapper* bin_mapper, int num_total_samples, int num_indices, const int* sample_indices, const double* sample_values) {
std::vector<int> ret;
if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
return ret;
}
int i = 0, j = 0;
while (i < num_total_samples) {
if (j < num_indices && sample_indices[j] < i) {
++j;
} else if (j < num_indices && sample_indices[j] == i) {
if (bin_mapper->ValueToBin(sample_values[j]) != bin_mapper->GetMostFreqBin()) {
ret.push_back(i);
}
++i;
} else {
ret.push_back(i++);
}
}
return ret;
}
std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers, std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
const std::vector<int>& find_order, const std::vector<int>& find_order,
int** sample_indices, int** sample_indices,
...@@ -147,6 +168,7 @@ std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMa ...@@ -147,6 +168,7 @@ std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMa
std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers, std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
int** sample_indices, int** sample_indices,
double** sample_values,
const int* num_per_col, const int* num_per_col,
int num_sample_col, int num_sample_col,
size_t total_sample_cnt, size_t total_sample_cnt,
...@@ -187,8 +209,23 @@ std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_ ...@@ -187,8 +209,23 @@ std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_
for (auto sidx : sorted_idx) { for (auto sidx : sorted_idx) {
feature_order_by_cnt.push_back(used_features[sidx]); feature_order_by_cnt.push_back(used_features[sidx]);
} }
auto features_in_group = FindGroups(bin_mappers, used_features, sample_indices, num_per_col, num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu); std::vector<std::vector<int>> tmp_indices;
auto group2 = FindGroups(bin_mappers, feature_order_by_cnt, sample_indices, num_per_col, num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu); std::vector<int> tmp_num_per_col(num_sample_col, 0);
for (auto fidx : used_features) {
if (fidx >= num_sample_col) {
continue;
}
auto ret = FixSampleIndices(bin_mappers[fidx].get(), static_cast<int>(total_sample_cnt), num_per_col[fidx], sample_indices[fidx], sample_values[fidx]);
if (!ret.empty()) {
tmp_indices.push_back(ret);
tmp_num_per_col[fidx] = static_cast<int>(ret.size());
sample_indices[fidx] = tmp_indices.back().data();
} else {
tmp_num_per_col[fidx] = num_per_col[fidx];
}
}
auto features_in_group = FindGroups(bin_mappers, used_features, sample_indices, tmp_num_per_col.data(), num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
auto group2 = FindGroups(bin_mappers, feature_order_by_cnt, sample_indices, tmp_num_per_col.data(), num_sample_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
if (features_in_group.size() > group2.size()) { if (features_in_group.size() > group2.size()) {
features_in_group = group2; features_in_group = group2;
} }
...@@ -230,6 +267,7 @@ void Dataset::Construct( ...@@ -230,6 +267,7 @@ void Dataset::Construct(
int num_total_features, int num_total_features,
const std::vector<std::vector<double>>& forced_bins, const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices, int** sample_non_zero_indices,
double** sample_values,
const int* num_per_col, const int* num_per_col,
int num_sample_col, int num_sample_col,
size_t total_sample_cnt, size_t total_sample_cnt,
...@@ -252,7 +290,7 @@ void Dataset::Construct( ...@@ -252,7 +290,7 @@ void Dataset::Construct(
if (io_config.enable_bundle && !used_features.empty()) { if (io_config.enable_bundle && !used_features.empty()) {
features_in_group = FastFeatureBundling(*bin_mappers, features_in_group = FastFeatureBundling(*bin_mappers,
sample_non_zero_indices, num_per_col, num_sample_col, total_sample_cnt, sample_non_zero_indices, sample_values, num_per_col, num_sample_col, total_sample_cnt,
used_features, io_config.max_conflict_rate, used_features, io_config.max_conflict_rate,
num_data_, io_config.min_data_in_leaf, num_data_, io_config.min_data_in_leaf,
sparse_threshold_, io_config.is_enable_sparse, io_config.device_type == std::string("gpu")); sparse_threshold_, io_config.is_enable_sparse, io_config.device_type == std::string("gpu"));
...@@ -268,6 +306,7 @@ void Dataset::Construct( ...@@ -268,6 +306,7 @@ void Dataset::Construct(
real_feature_idx_.resize(num_features_); real_feature_idx_.resize(num_features_);
feature2group_.resize(num_features_); feature2group_.resize(num_features_);
feature2subfeature_.resize(num_features_); feature2subfeature_.resize(num_features_);
feature_need_push_zeros_.clear();
for (int i = 0; i < num_groups_; ++i) { for (int i = 0; i < num_groups_; ++i) {
auto cur_features = features_in_group[i]; auto cur_features = features_in_group[i];
int cur_cnt_features = static_cast<int>(cur_features.size()); int cur_cnt_features = static_cast<int>(cur_features.size());
...@@ -280,6 +319,9 @@ void Dataset::Construct( ...@@ -280,6 +319,9 @@ void Dataset::Construct(
feature2group_[cur_fidx] = i; feature2group_[cur_fidx] = i;
feature2subfeature_[cur_fidx] = j; feature2subfeature_[cur_fidx] = j;
cur_bin_mappers.emplace_back(ref_bin_mappers[real_fidx].release()); cur_bin_mappers.emplace_back(ref_bin_mappers[real_fidx].release());
if (cur_bin_mappers.back()->GetDefaultBin() != cur_bin_mappers.back()->GetMostFreqBin()) {
feature_need_push_zeros_.push_back(cur_fidx);
}
++cur_fidx; ++cur_fidx;
} }
feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>( feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
...@@ -453,6 +495,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) { ...@@ -453,6 +495,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
monotone_types_ = dataset->monotone_types_; monotone_types_ = dataset->monotone_types_;
feature_penalty_ = dataset->feature_penalty_; feature_penalty_ = dataset->feature_penalty_;
forced_bin_bounds_ = dataset->forced_bin_bounds_; forced_bin_bounds_ = dataset->forced_bin_bounds_;
feature_need_push_zeros_ = dataset->feature_need_push_zeros_;
} }
void Dataset::CreateValid(const Dataset* dataset) { void Dataset::CreateValid(const Dataset* dataset) {
...@@ -464,9 +507,13 @@ void Dataset::CreateValid(const Dataset* dataset) { ...@@ -464,9 +507,13 @@ void Dataset::CreateValid(const Dataset* dataset) {
feature2group_.clear(); feature2group_.clear();
feature2subfeature_.clear(); feature2subfeature_.clear();
// copy feature bin mapper data // copy feature bin mapper data
feature_need_push_zeros_.clear();
for (int i = 0; i < num_features_; ++i) { for (int i = 0; i < num_features_; ++i) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers; std::vector<std::unique_ptr<BinMapper>> bin_mappers;
bin_mappers.emplace_back(new BinMapper(*(dataset->FeatureBinMapper(i)))); bin_mappers.emplace_back(new BinMapper(*(dataset->FeatureBinMapper(i))));
if (bin_mappers.back()->GetDefaultBin() != bin_mappers.back()->GetMostFreqBin()) {
feature_need_push_zeros_.push_back(i);
}
feature_groups_.emplace_back(new FeatureGroup( feature_groups_.emplace_back(new FeatureGroup(
1, 1,
&bin_mappers, &bin_mappers,
...@@ -812,7 +859,7 @@ void Dataset::DumpTextFile(const char* text_filename) { ...@@ -812,7 +859,7 @@ void Dataset::DumpTextFile(const char* text_filename) {
if (inner_feature_idx < 0) { if (inner_feature_idx < 0) {
fprintf(file, "NA, "); fprintf(file, "NA, ");
} else { } else {
fprintf(file, "%d, ", iterators[inner_feature_idx]->RawGet(i)); fprintf(file, "%d, ", iterators[inner_feature_idx]->Get(i));
} }
} }
} }
...@@ -999,17 +1046,17 @@ void Dataset::FixHistogram(int feature_idx, double sum_gradient, double sum_hess ...@@ -999,17 +1046,17 @@ void Dataset::FixHistogram(int feature_idx, double sum_gradient, double sum_hess
const int group = feature2group_[feature_idx]; const int group = feature2group_[feature_idx];
const int sub_feature = feature2subfeature_[feature_idx]; const int sub_feature = feature2subfeature_[feature_idx];
const BinMapper* bin_mapper = feature_groups_[group]->bin_mappers_[sub_feature].get(); const BinMapper* bin_mapper = feature_groups_[group]->bin_mappers_[sub_feature].get();
const int default_bin = bin_mapper->GetDefaultBin(); const int most_freq_bin = bin_mapper->GetMostFreqBin();
if (default_bin > 0) { if (most_freq_bin > 0) {
const int num_bin = bin_mapper->num_bin(); const int num_bin = bin_mapper->num_bin();
data[default_bin].sum_gradients = sum_gradient; data[most_freq_bin].sum_gradients = sum_gradient;
data[default_bin].sum_hessians = sum_hessian; data[most_freq_bin].sum_hessians = sum_hessian;
data[default_bin].cnt = num_data; data[most_freq_bin].cnt = num_data;
for (int i = 0; i < num_bin; ++i) { for (int i = 0; i < num_bin; ++i) {
if (i != default_bin) { if (i != most_freq_bin) {
data[default_bin].sum_gradients -= data[i].sum_gradients; data[most_freq_bin].sum_gradients -= data[i].sum_gradients;
data[default_bin].sum_hessians -= data[i].sum_hessians; data[most_freq_bin].sum_hessians -= data[i].sum_hessians;
data[default_bin].cnt -= data[i].cnt; data[most_freq_bin].cnt -= data[i].cnt;
} }
} }
} }
......
...@@ -717,7 +717,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -717,7 +717,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
} }
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, num_per_col, num_col, total_sample_size, config_); dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
dataset->set_feature_names(feature_names_); dataset->set_feature_names(feature_names_);
return dataset.release(); return dataset.release();
} }
...@@ -1040,8 +1040,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -1040,8 +1040,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
cp_ptr += bin_mappers[i]->SizesInByte(); cp_ptr += bin_mappers[i]->SizesInByte();
} }
} }
sample_values.clear();
dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(), dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Common::Vector2Ptr<double>(&sample_values).data(),
Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_); Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
} }
...@@ -1066,11 +1066,13 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat ...@@ -1066,11 +1066,13 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
ref_text_data[i].clear(); ref_text_data[i].clear();
// shrink_to_fit will be very slow in linux, and seems not free memory, disable for now // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
// text_reader_->Lines()[i].shrink_to_fit(); // text_reader_->Lines()[i].shrink_to_fit();
std::vector<bool> is_feature_added(dataset->num_features_, false);
// push data // push data
for (auto& inner_data : oneline_features) { for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; } if (inner_data.first >= dataset->num_total_features_) { continue; }
int feature_idx = dataset->used_feature_map_[inner_data.first]; int feature_idx = dataset->used_feature_map_[inner_data.first];
if (feature_idx >= 0) { if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
// if is used feature // if is used feature
int group = dataset->feature2group_[feature_idx]; int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx];
...@@ -1083,6 +1085,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat ...@@ -1083,6 +1085,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
} }
} }
} }
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -1110,10 +1113,12 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat ...@@ -1110,10 +1113,12 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
// shrink_to_fit will be very slow in linux, and seems not free memory, disable for now // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
// text_reader_->Lines()[i].shrink_to_fit(); // text_reader_->Lines()[i].shrink_to_fit();
// push data // push data
std::vector<bool> is_feature_added(dataset->num_features_, false);
for (auto& inner_data : oneline_features) { for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; } if (inner_data.first >= dataset->num_total_features_) { continue; }
int feature_idx = dataset->used_feature_map_[inner_data.first]; int feature_idx = dataset->used_feature_map_[inner_data.first];
if (feature_idx >= 0) { if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
// if is used feature // if is used feature
int group = dataset->feature2group_[feature_idx]; int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx];
...@@ -1126,6 +1131,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat ...@@ -1126,6 +1131,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
} }
} }
} }
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
...@@ -1167,11 +1173,13 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -1167,11 +1173,13 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
} }
// set label // set label
dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label)); dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
std::vector<bool> is_feature_added(dataset->num_features_, false);
// push data // push data
for (auto& inner_data : oneline_features) { for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; } if (inner_data.first >= dataset->num_total_features_) { continue; }
int feature_idx = dataset->used_feature_map_[inner_data.first]; int feature_idx = dataset->used_feature_map_[inner_data.first];
if (feature_idx >= 0) { if (feature_idx >= 0) {
is_feature_added[feature_idx] = true;
// if is used feature // if is used feature
int group = dataset->feature2group_[feature_idx]; int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx];
...@@ -1184,6 +1192,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -1184,6 +1192,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
} }
} }
} }
dataset->FinishOneRow(tid, i, is_feature_added);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
......
...@@ -19,11 +19,11 @@ class DenseBin; ...@@ -19,11 +19,11 @@ class DenseBin;
template <typename VAL_T> template <typename VAL_T>
class DenseBinIterator: public BinIterator { class DenseBinIterator: public BinIterator {
public: public:
explicit DenseBinIterator(const DenseBin<VAL_T>* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) explicit DenseBinIterator(const DenseBin<VAL_T>* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
: bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)), : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
max_bin_(static_cast<VAL_T>(max_bin)), max_bin_(static_cast<VAL_T>(max_bin)),
default_bin_(static_cast<VAL_T>(default_bin)) { most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
if (default_bin_ == 0) { if (most_freq_bin_ == 0) {
offset_ = 1; offset_ = 1;
} else { } else {
offset_ = 0; offset_ = 0;
...@@ -37,7 +37,7 @@ class DenseBinIterator: public BinIterator { ...@@ -37,7 +37,7 @@ class DenseBinIterator: public BinIterator {
const DenseBin<VAL_T>* bin_data_; const DenseBin<VAL_T>* bin_data_;
VAL_T min_bin_; VAL_T min_bin_;
VAL_T max_bin_; VAL_T max_bin_;
VAL_T default_bin_; VAL_T most_freq_bin_;
uint8_t offset_; uint8_t offset_;
}; };
/*! /*!
...@@ -66,7 +66,7 @@ class DenseBin: public Bin { ...@@ -66,7 +66,7 @@ class DenseBin: public Bin {
} }
} }
BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override; BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end, void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
const score_t* ordered_gradients, const score_t* ordered_hessians, const score_t* ordered_gradients, const score_t* ordered_hessians,
...@@ -128,7 +128,7 @@ class DenseBin: public Bin { ...@@ -128,7 +128,7 @@ class DenseBin: public Bin {
} }
data_size_t Split( data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data, uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
...@@ -136,21 +136,23 @@ class DenseBin: public Bin { ...@@ -136,21 +136,23 @@ class DenseBin: public Bin {
const VAL_T minb = static_cast<VAL_T>(min_bin); const VAL_T minb = static_cast<VAL_T>(min_bin);
const VAL_T maxb = static_cast<VAL_T>(max_bin); const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin); VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
if (default_bin == 0) { VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1; th -= 1;
t_default_bin -= 1; t_default_bin -= 1;
t_most_freq_bin -= 1;
} }
data_size_t lte_count = 0; data_size_t lte_count = 0;
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (most_freq_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (missing_type == MissingType::NaN) { if (missing_type == MissingType::NaN) {
if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (default_left) { if (default_left) {
missing_default_indices = lte_indices; missing_default_indices = lte_indices;
missing_default_count = &lte_count; missing_default_count = &lte_count;
...@@ -158,10 +160,10 @@ class DenseBin: public Bin { ...@@ -158,10 +160,10 @@ class DenseBin: public Bin {
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx]; const VAL_T bin = data_[idx];
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin == maxb) {
default_indices[(*default_count)++] = idx;
} else if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx; missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) { } else if (bin > th) {
gt_indices[gt_count++] = idx; gt_indices[gt_count++] = idx;
} else { } else {
...@@ -169,19 +171,36 @@ class DenseBin: public Bin { ...@@ -169,19 +171,36 @@ class DenseBin: public Bin {
} }
} }
} else { } else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) { if ((default_left && missing_type == MissingType::Zero)
default_indices = lte_indices; || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_count = &lte_count; missing_default_indices = lte_indices;
missing_default_count = &lte_count;
} }
for (data_size_t i = 0; i < num_data; ++i) { if (default_bin == most_freq_bin) {
const data_size_t idx = data_indices[i]; for (data_size_t i = 0; i < num_data; ++i) {
const VAL_T bin = data_[idx]; const data_size_t idx = data_indices[i];
if (bin < minb || bin > maxb || t_default_bin == bin) { const VAL_T bin = data_[idx];
default_indices[(*default_count)++] = idx; if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
} else if (bin > th) { missing_default_indices[(*missing_default_count)++] = idx;
gt_indices[gt_count++] = idx; } else if (bin > th) {
} else { gt_indices[gt_count++] = idx;
lte_indices[lte_count++] = idx; } else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin == t_default_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
} }
} }
} }
...@@ -189,7 +208,7 @@ class DenseBin: public Bin { ...@@ -189,7 +208,7 @@ class DenseBin: public Bin {
} }
data_size_t SplitCategorical( data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data, const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
...@@ -197,7 +216,7 @@ class DenseBin: public Bin { ...@@ -197,7 +216,7 @@ class DenseBin: public Bin {
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, default_bin)) { if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
...@@ -271,7 +290,7 @@ uint32_t DenseBinIterator<VAL_T>::Get(data_size_t idx) { ...@@ -271,7 +290,7 @@ uint32_t DenseBinIterator<VAL_T>::Get(data_size_t idx) {
if (ret >= min_bin_ && ret <= max_bin_) { if (ret >= min_bin_ && ret <= max_bin_) {
return ret - min_bin_ + offset_; return ret - min_bin_ + offset_;
} else { } else {
return default_bin_; return most_freq_bin_;
} }
} }
...@@ -281,8 +300,8 @@ inline uint32_t DenseBinIterator<VAL_T>::RawGet(data_size_t idx) { ...@@ -281,8 +300,8 @@ inline uint32_t DenseBinIterator<VAL_T>::RawGet(data_size_t idx) {
} }
template <typename VAL_T> template <typename VAL_T>
BinIterator* DenseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const { BinIterator* DenseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
return new DenseBinIterator<VAL_T>(this, min_bin, max_bin, default_bin); return new DenseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
} }
} // namespace LightGBM } // namespace LightGBM
......
...@@ -17,11 +17,11 @@ class Dense4bitsBin; ...@@ -17,11 +17,11 @@ class Dense4bitsBin;
class Dense4bitsBinIterator : public BinIterator { class Dense4bitsBinIterator : public BinIterator {
public: public:
explicit Dense4bitsBinIterator(const Dense4bitsBin* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) explicit Dense4bitsBinIterator(const Dense4bitsBin* bin_data, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
: bin_data_(bin_data), min_bin_(static_cast<uint8_t>(min_bin)), : bin_data_(bin_data), min_bin_(static_cast<uint8_t>(min_bin)),
max_bin_(static_cast<uint8_t>(max_bin)), max_bin_(static_cast<uint8_t>(max_bin)),
default_bin_(static_cast<uint8_t>(default_bin)) { most_freq_bin_(static_cast<uint8_t>(most_freq_bin)) {
if (default_bin_ == 0) { if (most_freq_bin_ == 0) {
offset_ = 1; offset_ = 1;
} else { } else {
offset_ = 0; offset_ = 0;
...@@ -35,7 +35,7 @@ class Dense4bitsBinIterator : public BinIterator { ...@@ -35,7 +35,7 @@ class Dense4bitsBinIterator : public BinIterator {
const Dense4bitsBin* bin_data_; const Dense4bitsBin* bin_data_;
uint8_t min_bin_; uint8_t min_bin_;
uint8_t max_bin_; uint8_t max_bin_;
uint8_t default_bin_; uint8_t most_freq_bin_;
uint8_t offset_; uint8_t offset_;
}; };
...@@ -71,7 +71,7 @@ class Dense4bitsBin : public Bin { ...@@ -71,7 +71,7 @@ class Dense4bitsBin : public Bin {
} }
} }
inline BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override; inline BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end, void ConstructHistogram(const data_size_t* data_indices, data_size_t start, data_size_t end,
const score_t* ordered_gradients, const score_t* ordered_hessians, const score_t* ordered_gradients, const score_t* ordered_hessians,
...@@ -134,7 +134,7 @@ class Dense4bitsBin : public Bin { ...@@ -134,7 +134,7 @@ class Dense4bitsBin : public Bin {
} }
data_size_t Split( data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data, uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
...@@ -142,21 +142,23 @@ class Dense4bitsBin : public Bin { ...@@ -142,21 +142,23 @@ class Dense4bitsBin : public Bin {
const uint8_t minb = static_cast<uint8_t>(min_bin); const uint8_t minb = static_cast<uint8_t>(min_bin);
const uint8_t maxb = static_cast<uint8_t>(max_bin); const uint8_t maxb = static_cast<uint8_t>(max_bin);
uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin); uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin);
if (default_bin == 0) { uint8_t t_most_freq_bin = static_cast<uint8_t>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1; th -= 1;
t_default_bin -= 1; t_default_bin -= 1;
t_most_freq_bin -= 1;
} }
data_size_t lte_count = 0; data_size_t lte_count = 0;
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (most_freq_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (missing_type == MissingType::NaN) { if (missing_type == MissingType::NaN) {
if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (default_left) { if (default_left) {
missing_default_indices = lte_indices; missing_default_indices = lte_indices;
missing_default_count = &lte_count; missing_default_count = &lte_count;
...@@ -164,10 +166,10 @@ class Dense4bitsBin : public Bin { ...@@ -164,10 +166,10 @@ class Dense4bitsBin : public Bin {
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin == maxb) {
default_indices[(*default_count)++] = idx;
} else if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx; missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) { } else if (bin > th) {
gt_indices[gt_count++] = idx; gt_indices[gt_count++] = idx;
} else { } else {
...@@ -175,19 +177,36 @@ class Dense4bitsBin : public Bin { ...@@ -175,19 +177,36 @@ class Dense4bitsBin : public Bin {
} }
} }
} else { } else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) { if ((default_left && missing_type == MissingType::Zero)
default_indices = lte_indices; || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_count = &lte_count; missing_default_indices = lte_indices;
missing_default_count = &lte_count;
} }
for (data_size_t i = 0; i < num_data; ++i) { if (default_bin == most_freq_bin) {
const data_size_t idx = data_indices[i]; for (data_size_t i = 0; i < num_data; ++i) {
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const data_size_t idx = data_indices[i];
if (bin < minb || bin > maxb || t_default_bin == bin) { const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
default_indices[(*default_count)++] = idx; if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
} else if (bin > th) { missing_default_indices[(*missing_default_count)++] = idx;
gt_indices[gt_count++] = idx; } else if (bin > th) {
} else { gt_indices[gt_count++] = idx;
lte_indices[lte_count++] = idx; } else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin == t_default_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
} }
} }
} }
...@@ -195,7 +214,7 @@ class Dense4bitsBin : public Bin { ...@@ -195,7 +214,7 @@ class Dense4bitsBin : public Bin {
} }
data_size_t SplitCategorical( data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data, const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
...@@ -203,7 +222,7 @@ class Dense4bitsBin : public Bin { ...@@ -203,7 +222,7 @@ class Dense4bitsBin : public Bin {
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, default_bin)) { if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
...@@ -303,7 +322,7 @@ uint32_t Dense4bitsBinIterator::Get(data_size_t idx) { ...@@ -303,7 +322,7 @@ uint32_t Dense4bitsBinIterator::Get(data_size_t idx) {
if (bin >= min_bin_ && bin <= max_bin_) { if (bin >= min_bin_ && bin <= max_bin_) {
return bin - min_bin_ + offset_; return bin - min_bin_ + offset_;
} else { } else {
return default_bin_; return most_freq_bin_;
} }
} }
...@@ -311,8 +330,8 @@ uint32_t Dense4bitsBinIterator::RawGet(data_size_t idx) { ...@@ -311,8 +330,8 @@ uint32_t Dense4bitsBinIterator::RawGet(data_size_t idx) {
return (bin_data_->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; return (bin_data_->data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
} }
inline BinIterator* Dense4bitsBin::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const { inline BinIterator* Dense4bitsBin::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
return new Dense4bitsBinIterator(this, min_bin, max_bin, default_bin); return new Dense4bitsBinIterator(this, min_bin, max_bin, most_freq_bin);
} }
} // namespace LightGBM } // namespace LightGBM
......
...@@ -26,11 +26,11 @@ template <typename VAL_T> ...@@ -26,11 +26,11 @@ template <typename VAL_T>
class SparseBinIterator: public BinIterator { class SparseBinIterator: public BinIterator {
public: public:
SparseBinIterator(const SparseBin<VAL_T>* bin_data, SparseBinIterator(const SparseBin<VAL_T>* bin_data,
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin)
: bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)), : bin_data_(bin_data), min_bin_(static_cast<VAL_T>(min_bin)),
max_bin_(static_cast<VAL_T>(max_bin)), max_bin_(static_cast<VAL_T>(max_bin)),
default_bin_(static_cast<VAL_T>(default_bin)) { most_freq_bin_(static_cast<VAL_T>(most_freq_bin)) {
if (default_bin_ == 0) { if (most_freq_bin_ == 0) {
offset_ = 1; offset_ = 1;
} else { } else {
offset_ = 0; offset_ = 0;
...@@ -50,7 +50,7 @@ class SparseBinIterator: public BinIterator { ...@@ -50,7 +50,7 @@ class SparseBinIterator: public BinIterator {
if (ret >= min_bin_ && ret <= max_bin_) { if (ret >= min_bin_ && ret <= max_bin_) {
return ret - min_bin_ + offset_; return ret - min_bin_ + offset_;
} else { } else {
return default_bin_; return most_freq_bin_;
} }
} }
...@@ -62,7 +62,7 @@ class SparseBinIterator: public BinIterator { ...@@ -62,7 +62,7 @@ class SparseBinIterator: public BinIterator {
data_size_t i_delta_; data_size_t i_delta_;
VAL_T min_bin_; VAL_T min_bin_;
VAL_T max_bin_; VAL_T max_bin_;
VAL_T default_bin_; VAL_T most_freq_bin_;
uint8_t offset_; uint8_t offset_;
}; };
...@@ -100,7 +100,7 @@ class SparseBin: public Bin { ...@@ -100,7 +100,7 @@ class SparseBin: public Bin {
} }
} }
BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const override; BinIterator* GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const override;
void ConstructHistogram(const data_size_t*, data_size_t, data_size_t, const score_t*, void ConstructHistogram(const data_size_t*, data_size_t, data_size_t, const score_t*,
const score_t*, HistogramBinEntry*) const override { const score_t*, HistogramBinEntry*) const override {
...@@ -145,32 +145,34 @@ class SparseBin: public Bin { ...@@ -145,32 +145,34 @@ class SparseBin: public Bin {
} }
} }
data_size_t Split( data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t most_freq_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data, uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
// not need to split
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
VAL_T th = static_cast<VAL_T>(threshold + min_bin); VAL_T th = static_cast<VAL_T>(threshold + min_bin);
const VAL_T minb = static_cast<VAL_T>(min_bin); const VAL_T minb = static_cast<VAL_T>(min_bin);
const VAL_T maxb = static_cast<VAL_T>(max_bin); const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin); VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
if (default_bin == 0) { VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1; th -= 1;
t_default_bin -= 1; t_default_bin -= 1;
t_most_freq_bin -= 1;
} }
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
data_size_t lte_count = 0; data_size_t lte_count = 0;
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
if (most_freq_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (missing_type == MissingType::NaN) { if (missing_type == MissingType::NaN) {
if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count;
if (default_left) { if (default_left) {
missing_default_indices = lte_indices; missing_default_indices = lte_indices;
missing_default_count = &lte_count; missing_default_count = &lte_count;
...@@ -178,10 +180,10 @@ class SparseBin: public Bin { ...@@ -178,10 +180,10 @@ class SparseBin: public Bin {
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx); const VAL_T bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin == maxb) {
default_indices[(*default_count)++] = idx;
} else if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx; missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) { } else if (bin > th) {
gt_indices[gt_count++] = idx; gt_indices[gt_count++] = idx;
} else { } else {
...@@ -189,19 +191,36 @@ class SparseBin: public Bin { ...@@ -189,19 +191,36 @@ class SparseBin: public Bin {
} }
} }
} else { } else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) { if ((default_left && missing_type == MissingType::Zero)
default_indices = lte_indices; || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_count = &lte_count; missing_default_indices = lte_indices;
missing_default_count = &lte_count;
} }
for (data_size_t i = 0; i < num_data; ++i) { if (default_bin == most_freq_bin) {
const data_size_t idx = data_indices[i]; for (data_size_t i = 0; i < num_data; ++i) {
const VAL_T bin = iterator.InnerRawGet(idx); const data_size_t idx = data_indices[i];
if (bin < minb || bin > maxb || t_default_bin == bin) { const VAL_T bin = iterator.InnerRawGet(idx);
default_indices[(*default_count)++] = idx; if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
} else if (bin > th) { missing_default_indices[(*missing_default_count)++] = idx;
gt_indices[gt_count++] = idx; } else if (bin > th) {
} else { gt_indices[gt_count++] = idx;
lte_indices[lte_count++] = idx; } else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin == t_default_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
} }
} }
} }
...@@ -209,7 +228,7 @@ class SparseBin: public Bin { ...@@ -209,7 +228,7 @@ class SparseBin: public Bin {
} }
data_size_t SplitCategorical( data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data, const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
...@@ -218,7 +237,7 @@ class SparseBin: public Bin { ...@@ -218,7 +237,7 @@ class SparseBin: public Bin {
SparseBinIterator<VAL_T> iterator(this, data_indices[0]); SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (Common::FindInBitset(threshold, num_threahold, default_bin)) { if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
...@@ -464,8 +483,8 @@ inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) { ...@@ -464,8 +483,8 @@ inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
} }
template <typename VAL_T> template <typename VAL_T>
BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t default_bin) const { BinIterator* SparseBin<VAL_T>::GetIterator(uint32_t min_bin, uint32_t max_bin, uint32_t most_freq_bin) const {
return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, default_bin); return new SparseBinIterator<VAL_T>(this, min_bin, max_bin, most_freq_bin);
} }
} // namespace LightGBM } // namespace LightGBM
......
...@@ -61,7 +61,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -61,7 +61,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed)); int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
feature_distribution[cur_min_machine].push_back(inner_feature_index); feature_distribution[cur_min_machine].push_back(inner_feature_index);
auto num_bin = this->train_data_->FeatureNumBin(inner_feature_index); auto num_bin = this->train_data_->FeatureNumBin(inner_feature_index);
if (this->train_data_->FeatureBinMapper(inner_feature_index)->GetDefaultBin() == 0) { if (this->train_data_->FeatureBinMapper(inner_feature_index)->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
num_bins_distributed[cur_min_machine] += num_bin; num_bins_distributed[cur_min_machine] += num_bin;
...@@ -79,7 +79,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -79,7 +79,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
block_len_[i] = 0; block_len_[i] = 0;
for (auto fid : feature_distribution[i]) { for (auto fid : feature_distribution[i]) {
auto num_bin = this->train_data_->FeatureNumBin(fid); auto num_bin = this->train_data_->FeatureNumBin(fid);
if (this->train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) { if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
block_len_[i] += num_bin * sizeof(HistogramBinEntry); block_len_[i] += num_bin * sizeof(HistogramBinEntry);
...@@ -98,7 +98,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -98,7 +98,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
for (auto fid : feature_distribution[i]) { for (auto fid : feature_distribution[i]) {
buffer_write_start_pos_[fid] = bin_size; buffer_write_start_pos_[fid] = bin_size;
auto num_bin = this->train_data_->FeatureNumBin(fid); auto num_bin = this->train_data_->FeatureNumBin(fid);
if (this->train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) { if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
bin_size += num_bin * sizeof(HistogramBinEntry); bin_size += num_bin * sizeof(HistogramBinEntry);
...@@ -110,7 +110,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -110,7 +110,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
for (auto fid : feature_distribution[rank_]) { for (auto fid : feature_distribution[rank_]) {
buffer_read_start_pos_[fid] = bin_size; buffer_read_start_pos_[fid] = bin_size;
auto num_bin = this->train_data_->FeatureNumBin(fid); auto num_bin = this->train_data_->FeatureNumBin(fid);
if (this->train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) { if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
bin_size += num_bin * sizeof(HistogramBinEntry); bin_size += num_bin * sizeof(HistogramBinEntry);
......
...@@ -710,7 +710,7 @@ class HistogramPool { ...@@ -710,7 +710,7 @@ class HistogramPool {
feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type(); feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type();
feature_metas_[i].monotone_type = train_data->FeatureMonotone(i); feature_metas_[i].monotone_type = train_data->FeatureMonotone(i);
feature_metas_[i].penalty = train_data->FeaturePenalte(i); feature_metas_[i].penalty = train_data->FeaturePenalte(i);
if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) { if (train_data->FeatureBinMapper(i)->GetMostFreqBin() == 0) {
feature_metas_[i].offset = 1; feature_metas_[i].offset = 1;
} else { } else {
feature_metas_[i].offset = 0; feature_metas_[i].offset = 0;
...@@ -740,7 +740,7 @@ class HistogramPool { ...@@ -740,7 +740,7 @@ class HistogramPool {
offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j)); offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j));
pool_[i][j].Init(data_[i].data() + offset, &feature_metas_[j]); pool_[i][j].Init(data_[i].data() + offset, &feature_metas_[j]);
auto num_bin = train_data->FeatureNumBin(j); auto num_bin = train_data->FeatureNumBin(j);
if (train_data->FeatureBinMapper(j)->GetDefaultBin() == 0) { if (train_data->FeatureBinMapper(j)->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
offset += static_cast<uint64_t>(num_bin); offset += static_cast<uint64_t>(num_bin);
......
...@@ -74,7 +74,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -74,7 +74,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type(); feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type();
feature_metas_[i].monotone_type = train_data->FeatureMonotone(i); feature_metas_[i].monotone_type = train_data->FeatureMonotone(i);
feature_metas_[i].penalty = train_data->FeaturePenalte(i); feature_metas_[i].penalty = train_data->FeaturePenalte(i);
if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) { if (train_data->FeatureBinMapper(i)->GetMostFreqBin() == 0) {
feature_metas_[i].offset = 1; feature_metas_[i].offset = 1;
} else { } else {
feature_metas_[i].offset = 0; feature_metas_[i].offset = 0;
...@@ -88,7 +88,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -88,7 +88,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
smaller_leaf_histogram_array_global_[j].Init(smaller_leaf_histogram_data_.data() + offset, &feature_metas_[j]); smaller_leaf_histogram_array_global_[j].Init(smaller_leaf_histogram_data_.data() + offset, &feature_metas_[j]);
larger_leaf_histogram_array_global_[j].Init(larger_leaf_histogram_data_.data() + offset, &feature_metas_[j]); larger_leaf_histogram_array_global_[j].Init(larger_leaf_histogram_data_.data() + offset, &feature_metas_[j]);
auto num_bin = train_data->FeatureNumBin(j); auto num_bin = train_data->FeatureNumBin(j);
if (train_data->FeatureBinMapper(j)->GetDefaultBin() == 0) { if (train_data->FeatureBinMapper(j)->GetMostFreqBin() == 0) {
num_bin -= 1; num_bin -= 1;
} }
offset += static_cast<uint64_t>(num_bin); offset += static_cast<uint64_t>(num_bin);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment