Unverified Commit f94050a4 authored by Scott Votaw's avatar Scott Votaw Committed by GitHub
Browse files

fix: Adjust LGBM_DatasetCreateFromSampledColumn to handle distributed data (#5344)

* Adjust LGBM_DatasetCreateFromSampledColumn to handle distributed data better

* linting fix

* switch to 1 API with breaking change

* Fix python native call

* more python test fixes
parent 44fe591a
......@@ -118,7 +118,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
* \param ncol Number of columns
* \param num_per_col Size of each sampling column
* \param num_sample_row Number of sampled rows
* \param num_total_row Number of total rows
* \param num_local_row Total number of rows local to machine
* \param num_dist_row Number of total distributed rows
* \param parameters Additional parameters
* \param[out] out Created dataset
* \return 0 when succeed, -1 when failure happens
......@@ -128,7 +129,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int32_t ncol,
const int* num_per_col,
int32_t num_sample_row,
int32_t num_total_row,
int32_t num_local_row,
int64_t num_dist_row,
const char* parameters,
DatasetHandle* out);
......
......@@ -29,8 +29,12 @@ class DatasetLoader {
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* ConstructFromSampleData(double** sample_values,
int** sample_indices, int num_col, const int* num_per_col,
size_t total_sample_size, data_size_t num_data);
int** sample_indices,
int num_col,
const int* num_per_col,
size_t total_sample_size,
data_size_t num_local_data,
int64_t num_dist_data);
/*! \brief Disable copy */
DatasetLoader& operator=(const DatasetLoader&) = delete;
......
......@@ -1344,6 +1344,7 @@ class Dataset:
num_per_col_ptr,
ctypes.c_int32(sample_cnt),
ctypes.c_int32(total_nrow),
ctypes.c_int64(total_nrow),
c_str(params_str),
ctypes.byref(self.handle),
))
......
......@@ -974,13 +974,13 @@ int LGBM_DatasetCreateFromFile(const char* filename,
API_END();
}
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int** sample_indices,
int32_t ncol,
const int* num_per_col,
int32_t num_sample_row,
int32_t num_total_row,
int32_t num_local_row,
int64_t num_dist_row,
const char* parameters,
DatasetHandle* out) {
API_BEGIN();
......@@ -989,13 +989,16 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
config.Set(param);
OMP_SET_NUM_THREADS(config.num_threads);
DatasetLoader loader(config, nullptr, 1, nullptr);
*out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
*out = loader.ConstructFromSampleData(sample_data,
sample_indices,
ncol,
num_per_col,
num_sample_row,
static_cast<data_size_t>(num_total_row));
static_cast<data_size_t>(num_local_row),
num_dist_row);
API_END();
}
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row,
DatasetHandle* out) {
......@@ -1141,7 +1144,9 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
Vector2Ptr<int>(&sample_idx).data(),
ncol,
VectorSize<double>(sample_values).data(),
sample_cnt, total_nrow));
sample_cnt,
total_nrow,
total_nrow));
} else {
ret.reset(new Dataset(total_nrow));
ret->CreateValid(
......@@ -1216,7 +1221,9 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(num_col),
VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
sample_cnt,
nrow,
nrow));
} else {
ret.reset(new Dataset(nrow));
ret->CreateValid(
......@@ -1283,7 +1290,9 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(num_col),
VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
sample_cnt,
nrow,
nrow));
} else {
ret.reset(new Dataset(nrow));
ret->CreateValid(
......@@ -1355,7 +1364,9 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()),
VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
sample_cnt,
nrow,
nrow));
} else {
ret.reset(new Dataset(nrow));
ret->CreateValid(
......
......@@ -657,11 +657,14 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
return dataset.release();
}
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
int** sample_indices, int num_col, const int* num_per_col,
size_t total_sample_size, data_size_t num_data) {
CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
int** sample_indices,
int num_col,
const int* num_per_col,
size_t total_sample_size,
data_size_t num_local_data,
int64_t num_dist_data) {
CheckSampleSize(total_sample_size, static_cast<size_t>(num_dist_data));
int num_total_features = num_col;
if (Network::num_machines() > 1) {
num_total_features = Network::GlobalSyncUpByMax(num_total_features);
......@@ -685,7 +688,7 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);
const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_dist_data);
if (Network::num_machines() == 1) {
// if only one machine, find bin locally
OMP_INIT_EX();
......@@ -806,10 +809,10 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
}
}
CheckCategoricalFeatureNumBin(bin_mappers, config_.max_bin, config_.max_bin_by_feature);
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_local_data));
dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
if (dataset->has_raw()) {
dataset->ResizeRaw(num_data);
dataset->ResizeRaw(num_local_data);
}
dataset->set_feature_names(feature_names_);
return dataset.release();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment