Unverified Commit f94050a4 authored by Scott Votaw's avatar Scott Votaw Committed by GitHub
Browse files

fix: Adjust LGBM_DatasetCreateFromSampledColumn to handle distributed data (#5344)

* Adjust LGBM_DatasetCreateFromSampledColumn to handle distributed data better

* linting fix

* switch to 1 API with breaking change

* Fix python native call

* more python test fixes
parent 44fe591a
...@@ -118,7 +118,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -118,7 +118,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
* \param ncol Number of columns * \param ncol Number of columns
* \param num_per_col Size of each sampling column * \param num_per_col Size of each sampling column
* \param num_sample_row Number of sampled rows * \param num_sample_row Number of sampled rows
* \param num_total_row Number of total rows * \param num_local_row Total number of rows local to machine
* \param num_dist_row Number of total distributed rows
* \param parameters Additional parameters * \param parameters Additional parameters
* \param[out] out Created dataset * \param[out] out Created dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
...@@ -128,7 +129,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data, ...@@ -128,7 +129,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int32_t ncol, int32_t ncol,
const int* num_per_col, const int* num_per_col,
int32_t num_sample_row, int32_t num_sample_row,
int32_t num_total_row, int32_t num_local_row,
int64_t num_dist_row,
const char* parameters, const char* parameters,
DatasetHandle* out); DatasetHandle* out);
......
...@@ -29,8 +29,12 @@ class DatasetLoader { ...@@ -29,8 +29,12 @@ class DatasetLoader {
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data); LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* ConstructFromSampleData(double** sample_values, LIGHTGBM_EXPORT Dataset* ConstructFromSampleData(double** sample_values,
int** sample_indices, int num_col, const int* num_per_col, int** sample_indices,
size_t total_sample_size, data_size_t num_data); int num_col,
const int* num_per_col,
size_t total_sample_size,
data_size_t num_local_data,
int64_t num_dist_data);
/*! \brief Disable copy */ /*! \brief Disable copy */
DatasetLoader& operator=(const DatasetLoader&) = delete; DatasetLoader& operator=(const DatasetLoader&) = delete;
......
...@@ -1344,6 +1344,7 @@ class Dataset: ...@@ -1344,6 +1344,7 @@ class Dataset:
num_per_col_ptr, num_per_col_ptr,
ctypes.c_int32(sample_cnt), ctypes.c_int32(sample_cnt),
ctypes.c_int32(total_nrow), ctypes.c_int32(total_nrow),
ctypes.c_int64(total_nrow),
c_str(params_str), c_str(params_str),
ctypes.byref(self.handle), ctypes.byref(self.handle),
)) ))
......
...@@ -974,13 +974,13 @@ int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -974,13 +974,13 @@ int LGBM_DatasetCreateFromFile(const char* filename,
API_END(); API_END();
} }
int LGBM_DatasetCreateFromSampledColumn(double** sample_data, int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int** sample_indices, int** sample_indices,
int32_t ncol, int32_t ncol,
const int* num_per_col, const int* num_per_col,
int32_t num_sample_row, int32_t num_sample_row,
int32_t num_total_row, int32_t num_local_row,
int64_t num_dist_row,
const char* parameters, const char* parameters,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
...@@ -989,13 +989,16 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data, ...@@ -989,13 +989,16 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
config.Set(param); config.Set(param);
OMP_SET_NUM_THREADS(config.num_threads); OMP_SET_NUM_THREADS(config.num_threads);
DatasetLoader loader(config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
*out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col, *out = loader.ConstructFromSampleData(sample_data,
sample_indices,
ncol,
num_per_col,
num_sample_row, num_sample_row,
static_cast<data_size_t>(num_total_row)); static_cast<data_size_t>(num_local_row),
num_dist_row);
API_END(); API_END();
} }
int LGBM_DatasetCreateByReference(const DatasetHandle reference, int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row, int64_t num_total_row,
DatasetHandle* out) { DatasetHandle* out) {
...@@ -1141,7 +1144,9 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, ...@@ -1141,7 +1144,9 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
Vector2Ptr<int>(&sample_idx).data(), Vector2Ptr<int>(&sample_idx).data(),
ncol, ncol,
VectorSize<double>(sample_values).data(), VectorSize<double>(sample_values).data(),
sample_cnt, total_nrow)); sample_cnt,
total_nrow,
total_nrow));
} else { } else {
ret.reset(new Dataset(total_nrow)); ret.reset(new Dataset(total_nrow));
ret->CreateValid( ret->CreateValid(
...@@ -1216,7 +1221,9 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -1216,7 +1221,9 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
Vector2Ptr<int>(&sample_idx).data(), Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(num_col), static_cast<int>(num_col),
VectorSize<double>(sample_values).data(), VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt,
nrow,
nrow));
} else { } else {
ret.reset(new Dataset(nrow)); ret.reset(new Dataset(nrow));
ret->CreateValid( ret->CreateValid(
...@@ -1283,7 +1290,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, ...@@ -1283,7 +1290,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
Vector2Ptr<int>(&sample_idx).data(), Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(num_col), static_cast<int>(num_col),
VectorSize<double>(sample_values).data(), VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt,
nrow,
nrow));
} else { } else {
ret.reset(new Dataset(nrow)); ret.reset(new Dataset(nrow));
ret->CreateValid( ret->CreateValid(
...@@ -1355,7 +1364,9 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -1355,7 +1364,9 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
Vector2Ptr<int>(&sample_idx).data(), Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(sample_values.size()),
VectorSize<double>(sample_values).data(), VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt,
nrow,
nrow));
} else { } else {
ret.reset(new Dataset(nrow)); ret.reset(new Dataset(nrow));
ret->CreateValid( ret->CreateValid(
......
...@@ -657,11 +657,14 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -657,11 +657,14 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
return dataset.release(); return dataset.release();
} }
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values, Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
int** sample_indices, int num_col, const int* num_per_col, int** sample_indices,
size_t total_sample_size, data_size_t num_data) { int num_col,
CheckSampleSize(total_sample_size, static_cast<size_t>(num_data)); const int* num_per_col,
size_t total_sample_size,
data_size_t num_local_data,
int64_t num_dist_data) {
CheckSampleSize(total_sample_size, static_cast<size_t>(num_dist_data));
int num_total_features = num_col; int num_total_features = num_col;
if (Network::num_machines() > 1) { if (Network::num_machines() > 1) {
num_total_features = Network::GlobalSyncUpByMax(num_total_features); num_total_features = Network::GlobalSyncUpByMax(num_total_features);
...@@ -685,7 +688,7 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values, ...@@ -685,7 +688,7 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_); std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);
const data_size_t filter_cnt = static_cast<data_size_t>( const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data); static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_dist_data);
if (Network::num_machines() == 1) { if (Network::num_machines() == 1) {
// if only one machine, find bin locally // if only one machine, find bin locally
OMP_INIT_EX(); OMP_INIT_EX();
...@@ -806,10 +809,10 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values, ...@@ -806,10 +809,10 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
} }
} }
CheckCategoricalFeatureNumBin(bin_mappers, config_.max_bin, config_.max_bin_by_feature); CheckCategoricalFeatureNumBin(bin_mappers, config_.max_bin, config_.max_bin_by_feature);
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_local_data));
dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_); dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
if (dataset->has_raw()) { if (dataset->has_raw()) {
dataset->ResizeRaw(num_data); dataset->ResizeRaw(num_local_data);
} }
dataset->set_feature_names(feature_names_); dataset->set_feature_names(feature_names_);
return dataset.release(); return dataset.release();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment