Commit bcb9613e authored by Guolin Ke's avatar Guolin Ke
Browse files

fix subset bug.

parent 9e2ad71a
......@@ -49,6 +49,30 @@ public:
bin_data_.reset(Bin::CreateBin(num_data, num_total_bin_,
sparse_rate, is_enable_sparse, sparse_threshold, &is_sparse_));
}
FeatureGroup(int num_feature,
std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
data_size_t num_data, bool is_sparse) : num_feature_(num_feature) {
CHECK(static_cast<int>(bin_mappers.size()) == num_feature);
// use bin at zero to store default_bin
num_total_bin_ = 1;
bin_offsets_.emplace_back(num_total_bin_);
for (int i = 0; i < num_feature_; ++i) {
bin_mappers_.emplace_back(bin_mappers[i].release());
auto num_bin = bin_mappers_[i]->num_bin();
if (bin_mappers_[i]->GetDefaultBin() == 0) {
num_bin -= 1;
}
num_total_bin_ += num_bin;
bin_offsets_.emplace_back(num_total_bin_);
}
is_sparse_ = is_sparse;
if (is_sparse_) {
bin_data_.reset(Bin::CreateSparseBin(num_data, num_total_bin_));
} else {
bin_data_.reset(Bin::CreateDenseBin(num_data, num_total_bin_));
}
}
/*!
* \brief Constructor from memory
* \param memory Pointer of memory
......
......@@ -636,6 +636,7 @@ int LGBM_DatasetGetSubset(
IOConfig io_config;
io_config.Set(param);
auto full_dataset = reinterpret_cast<const Dataset*>(handle);
CHECK(num_used_row_indices > 0);
auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
ret->CopyFeatureMapperFrom(full_dataset);
ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
......
......@@ -131,13 +131,6 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
num_features_ = dataset->num_features_;
num_groups_ = dataset->num_groups_;
sparse_threshold_ = dataset->sparse_threshold_;
bool is_enable_sparse = false;
for (int i = 0; i < num_groups_; ++i) {
if (dataset->feature_groups_[i]->is_sparse_) {
is_enable_sparse = true;
break;
}
}
// copy feature bin mapper data
for (int i = 0; i < num_groups_; ++i) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers;
......@@ -148,8 +141,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
dataset->feature_groups_[i]->num_feature_,
bin_mappers,
num_data_,
dataset->sparse_threshold_,
is_enable_sparse));
dataset->feature_groups_[i]->is_sparse_));
}
feature_groups_.shrink_to_fit();
used_feature_map_ = dataset->used_feature_map_;
......@@ -233,7 +225,7 @@ void Dataset::ReSize(data_size_t num_data) {
void Dataset::CopySubset(const Dataset* fullset, const data_size_t* used_indices, data_size_t num_used_indices, bool need_meta_data) {
CHECK(num_used_indices == num_data_);
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (int group = 0; group < num_groups_; ++group) {
OMP_LOOP_EX_BEGIN();
feature_groups_[group]->CopySubset(fullset->feature_groups_[group].get(), used_indices, num_used_indices);
......
......@@ -258,7 +258,7 @@ public:
}
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const DenseBin<VAL_T>*>(full_bin);
auto other_bin = dynamic_cast<const DenseBin<VAL_T>*>(full_bin);
for (int i = 0; i < num_used_indices; ++i) {
data_[i] = other_bin->data_[used_indices[i]];
}
......
......@@ -316,7 +316,7 @@ public:
}
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const Dense4bitsBin*>(full_bin);
auto other_bin = dynamic_cast<const Dense4bitsBin*>(full_bin);
const data_size_t rest = num_used_indices & 1;
for (int i = 0; i < num_used_indices - rest; i += 2) {
data_size_t idx = used_indices[i];
......
......@@ -252,7 +252,6 @@ public:
}
void GetFastIndex() {
fast_index_.clear();
// get shift cnt
data_size_t mod_size = (num_data_ + kNumFastIndex - 1) / kNumFastIndex;
......@@ -333,10 +332,12 @@ public:
}
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const SparseBin<VAL_T>*>(full_bin);
SparseBinIterator<VAL_T> iterator(other_bin, used_indices[0]);
deltas_.clear();
vals_.clear();
auto other_bin = dynamic_cast<const SparseBin<VAL_T>*>(full_bin);
data_size_t start = 0;
if (num_used_indices > 0) {
start = used_indices[0];
}
SparseBinIterator<VAL_T> iterator(other_bin, start);
// transform to delta array
data_size_t last_idx = 0;
for (data_size_t i = 0; i < num_used_indices; ++i) {
......@@ -394,9 +395,15 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerRawGet(data_size_t idx) {
template <typename VAL_T>
inline void SparseBinIterator<VAL_T>::Reset(data_size_t start_idx) {
auto idx = start_idx >> bin_data_->fast_index_shift_;
if (static_cast<size_t>(idx) < bin_data_->fast_index_.size()) {
const auto fast_pair = bin_data_->fast_index_[start_idx >> bin_data_->fast_index_shift_];
i_delta_ = fast_pair.first;
cur_pos_ = fast_pair.second;
} else {
i_delta_ = -1;
cur_pos_ = 0;
}
}
template <typename VAL_T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment