Commit 4cf9376d authored by Guolin Ke's avatar Guolin Ke
Browse files

support specific num_threads for data loading in c_api.

parent 68980887
......@@ -22,7 +22,6 @@ public:
void ReThrow() {
if (ex_ptr_ != nullptr) {
std::rethrow_exception(ex_ptr_);
ex_ptr_ = nullptr;
}
}
void CaptureException() {
......
......@@ -323,9 +323,12 @@ int LGBM_DatasetCreateFromFile(const char* filename,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
DatasetLoader loader(io_config, nullptr, 1, filename);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
DatasetLoader loader(config.io_config, nullptr, 1, filename);
if (reference == nullptr) {
*out = loader.LoadFromFile(filename);
} else {
......@@ -346,9 +349,12 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
DatasetLoader loader(io_config, nullptr, 1, nullptr);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
num_sample_row,
static_cast<data_size_t>(num_total_row));
......@@ -433,14 +439,17 @@ int LGBM_DatasetCreateFromMat(const void* data,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
std::unique_ptr<Dataset> ret;
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
if (reference == nullptr) {
// sample data first
Random rand(io_config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Random rand(config.io_config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol);
......@@ -455,7 +464,7 @@ int LGBM_DatasetCreateFromMat(const void* data,
}
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()),
......@@ -494,15 +503,18 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
std::unique_ptr<Dataset> ret;
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (reference == nullptr) {
// sample data first
Random rand(io_config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Random rand(config.io_config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values;
......@@ -522,7 +534,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
}
}
CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr);
DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()),
......@@ -561,14 +573,17 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
std::unique_ptr<Dataset> ret;
int32_t nrow = static_cast<int32_t>(num_row);
if (reference == nullptr) {
// sample data first
Random rand(io_config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Random rand(config.io_config.data_random_seed);
int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
......@@ -588,7 +603,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
DatasetLoader loader(io_config, nullptr, 1, nullptr);
DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()),
......@@ -633,8 +648,11 @@ int LGBM_DatasetGetSubset(
DatasetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
OverallConfig config;
config.Set(param);
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
auto full_dataset = reinterpret_cast<const Dataset*>(handle);
CHECK(num_used_row_indices > 0);
auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
......
......@@ -227,7 +227,6 @@ void Dataset::Construct(
auto features_in_group = NoGroup(used_features);
if (io_config.enable_bundle) {
std::chrono::duration<double, std::milli> bundling_time_;
features_in_group = FastFeatureBundling(bin_mappers,
sample_non_zero_indices, num_per_col, total_sample_cnt,
used_features, io_config.max_conflict_rate,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment