Commit 3b13f95a authored by Guolin Ke's avatar Guolin Ke
Browse files

fix #1153

parent 12f55f75
...@@ -555,23 +555,19 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -555,23 +555,19 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size()); sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values; std::vector<std::vector<double>> sample_values(num_col);
std::vector<std::vector<int>> sample_idx; std::vector<std::vector<int>> sample_idx(num_col);
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
auto idx = sample_indices[i]; auto idx = sample_indices[i];
auto row = get_row_fun(static_cast<int>(idx)); auto row = get_row_fun(static_cast<int>(idx));
for (std::pair<int, double>& inner_data : row) { for (std::pair<int, double>& inner_data : row) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) { CHECK(static_cast<size_t>(inner_data.first) < num_col);
sample_values.resize(inner_data.first + 1);
sample_idx.resize(inner_data.first + 1);
}
if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) { if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
sample_values[inner_data.first].emplace_back(inner_data.second); sample_values[inner_data.first].emplace_back(inner_data.second);
sample_idx[inner_data.first].emplace_back(static_cast<int>(i)); sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
} }
} }
} }
CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(config.io_config, nullptr, 1, nullptr); DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(), Common::Vector2Ptr<int>(sample_idx).data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment