Commit bcb9613e authored by Guolin Ke's avatar Guolin Ke
Browse files

fix subset bug.

parent 9e2ad71a
...@@ -49,6 +49,30 @@ public: ...@@ -49,6 +49,30 @@ public:
bin_data_.reset(Bin::CreateBin(num_data, num_total_bin_, bin_data_.reset(Bin::CreateBin(num_data, num_total_bin_,
sparse_rate, is_enable_sparse, sparse_threshold, &is_sparse_)); sparse_rate, is_enable_sparse, sparse_threshold, &is_sparse_));
} }
FeatureGroup(int num_feature,
std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
data_size_t num_data, bool is_sparse) : num_feature_(num_feature) {
CHECK(static_cast<int>(bin_mappers.size()) == num_feature);
// use bin at zero to store default_bin
num_total_bin_ = 1;
bin_offsets_.emplace_back(num_total_bin_);
for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(bin_mappers[i].release());
auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) {
num_bin -= 1;
}
num_total_bin_ += num_bin;
bin_offsets_.emplace_back(num_total_bin_);
}
is_sparse_ = is_sparse;
if (is_sparse_) {
bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
} else {
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
}
}
/*! /*!
* \brief Constructor from memory * \brief Constructor from memory
* \param memory Pointer of memory * \param memory Pointer of memory
......
...@@ -636,6 +636,7 @@ int LGBM_DatasetGetSubset( ...@@ -636,6 +636,7 @@ int LGBM_DatasetGetSubset(
IOConfig io_config; IOConfig io_config;
io_config.Set(param); io_config.Set(param);
auto full_dataset = reinterpret_cast<const Dataset*>(handle); auto full_dataset = reinterpret_cast<const Dataset*>(handle);
CHECK(num_used_row_indices > 0);
auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices)); auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
ret->CopyFeatureMapperFrom(full_dataset); ret->CopyFeatureMapperFrom(full_dataset);
ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true); ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
......
...@@ -131,13 +131,6 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) { ...@@ -131,13 +131,6 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
num_features_ = dataset->num_features_; num_features_ = dataset->num_features_;
num_groups_ = dataset->num_groups_; num_groups_ = dataset->num_groups_;
sparse_threshold_ = dataset->sparse_threshold_; sparse_threshold_ = dataset->sparse_threshold_;
bool is_enable_sparse = false;
for (int i = 0; i < num_groups_; ++i) {
if (dataset->feature_groups_[i]->is_sparse_) {
is_enable_sparse = true;
break;
}
}
// copy feature bin mapper data // copy feature bin mapper data
for (int i = 0; i < num_groups_; ++i) { for (int i = 0; i < num_groups_; ++i) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers; std::vector<std::unique_ptr<BinMapper>> bin_mappers;
...@@ -148,8 +141,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) { ...@@ -148,8 +141,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
dataset->feature_groups_[i]->num_feature_, dataset->feature_groups_[i]->num_feature_,
bin_mappers, bin_mappers,
num_data_, num_data_,
dataset->sparse_threshold_, dataset->feature_groups_[i]->is_sparse_));
is_enable_sparse));
} }
feature_groups_.shrink_to_fit(); feature_groups_.shrink_to_fit();
used_feature_map_ = dataset->used_feature_map_; used_feature_map_ = dataset->used_feature_map_;
...@@ -233,7 +225,7 @@ void Dataset::ReSize(data_size_t num_data) { ...@@ -233,7 +225,7 @@ void Dataset::ReSize(data_size_t num_data) {
void Dataset::CopySubset(const Dataset* fullset, const data_size_t* used_indices, data_size_t num_used_indices, bool need_meta_data) { void Dataset::CopySubset(const Dataset* fullset, const data_size_t* used_indices, data_size_t num_used_indices, bool need_meta_data) {
CHECK(num_used_indices == num_data_); CHECK(num_used_indices == num_data_);
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int group = 0; group < num_groups_; ++group) { for (int group = 0; group < num_groups_; ++group) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
feature_groups_[group]->CopySubset(fullset->feature_groups_[group].get(), used_indices, num_used_indices); feature_groups_[group]->CopySubset(fullset->feature_groups_[group].get(), used_indices, num_used_indices);
......
...@@ -258,7 +258,7 @@ public: ...@@ -258,7 +258,7 @@ public:
} }
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override { void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const DenseBin<VAL_T>*>(full_bin); auto other_bin = dynamic_cast<const DenseBin<VAL_T>*>(full_bin);
for (int i = 0; i < num_used_indices; ++i) { for (int i = 0; i < num_used_indices; ++i) {
data_[i] = other_bin->data_[used_indices[i]]; data_[i] = other_bin->data_[used_indices[i]];
} }
......
...@@ -316,7 +316,7 @@ public: ...@@ -316,7 +316,7 @@ public:
} }
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override { void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const Dense4bitsBin*>(full_bin); auto other_bin = dynamic_cast<const Dense4bitsBin*>(full_bin);
const data_size_t rest = num_used_indices & 1; const data_size_t rest = num_used_indices & 1;
for (int i = 0; i < num_used_indices - rest; i += 2) { for (int i = 0; i < num_used_indices - rest; i += 2) {
data_size_t idx = used_indices[i]; data_size_t idx = used_indices[i];
......
...@@ -123,14 +123,14 @@ public: ...@@ -123,14 +123,14 @@ public:
} }
inline bool NextNonzero(data_size_t* i_delta, inline bool NextNonzero(data_size_t* i_delta,
data_size_t* cur_pos) const { data_size_t* cur_pos) const {
++(*i_delta); ++(*i_delta);
data_size_t shift = 0; data_size_t shift = 0;
data_size_t delta = deltas_[*i_delta]; data_size_t delta = deltas_[*i_delta];
while (*i_delta < num_vals_ && vals_[*i_delta] == 0) { while (*i_delta < num_vals_ && vals_[*i_delta] == 0) {
++(*i_delta); ++(*i_delta);
shift += 8; shift += 8;
delta |= static_cast<data_size_t>(deltas_[*i_delta]) << shift; delta |= static_cast<data_size_t>(deltas_[*i_delta]) << shift;
} }
*cur_pos += delta; *cur_pos += delta;
if (*i_delta < num_vals_) { if (*i_delta < num_vals_) {
...@@ -252,7 +252,6 @@ public: ...@@ -252,7 +252,6 @@ public:
} }
void GetFastIndex() { void GetFastIndex() {
fast_index_.clear(); fast_index_.clear();
// get shift cnt // get shift cnt
data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex; data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
...@@ -333,10 +332,12 @@ public: ...@@ -333,10 +332,12 @@ public:
} }
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override { void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const SparseBin<VAL_T>*>(full_bin); auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
SparseBinIterator<VAL_T> iterator(other_bin, used_indices[0]); data_size_t start = 0;
deltas_.clear(); if (num_used_indices > 0) {
vals_.clear(); start = used_indices[0];
}
SparseBinIterator<VAL_T> iterator(other_bin, start);
// transform to delta array // transform to delta array
data_size_t last_idx = 0; data_size_t last_idx = 0;
for (data_size_t i = 0; i < num_used_indices; ++i) { for (data_size_t i = 0; i < num_used_indices; ++i) {
...@@ -394,9 +395,15 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) { ...@@ -394,9 +395,15 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
template <typename VAL_T> template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) { inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
const auto fast_pair = bin_data_->fast_index_[start_idx >> bin_data_->fast_index_shift_]; auto idx = start_idx >> bin_data_->fast_index_shift_;
i_delta_ = fast_pair.first; if (static_cast<size_t>(idx) < bin_data_->fast_index_.size()) {
cur_pos_ = fast_pair.second; const auto fast_pair = bin_data_->fast_index_[start_idx >> bin_data_->fast_index_shift_];
i_delta_ = fast_pair.first;
cur_pos_ = fast_pair.second;
} else {
i_delta_ = -1;
cur_pos_ = 0;
}
} }
template <typename VAL_T> template <typename VAL_T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment