Commit 494fef34 authored by Guolin Ke's avatar Guolin Ke
Browse files

finish the expose interface of dataset

parent d41c78f9
...@@ -79,32 +79,6 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -79,32 +79,6 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
uint64_t nelem, uint64_t nelem,
uint64_t num_col,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief create a dataset from CSC format
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param float_type 0 for float_32 1 for float_64
* \param nindptr length of col_ptr, i.e. number of columns in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mappers with another dataset; nullptr means it is not used
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_row,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
......
...@@ -258,6 +258,15 @@ public: ...@@ -258,6 +258,15 @@ public:
} }
} }
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) {
for (auto& inner_data : feature_values) {
int feature_idx = used_feature_map_[inner_data.first];
if (feature_idx >= 0) {
features_[feature_idx]->PushData(tid, row_idx, inner_data.second);
}
}
}
inline void SetNumData(data_size_t num_data) { inline void SetNumData(data_size_t num_data) {
num_data_ = num_data; num_data_ = num_data;
} }
...@@ -266,6 +275,8 @@ public: ...@@ -266,6 +275,8 @@ public:
void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type); void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type);
void GetField(const char* field_name, uint64_t* out_len, const void** out_ptr, int* out_type);
/*! /*!
* \brief Save current dataset into binary file, will save to "filename.bin" * \brief Save current dataset into binary file, will save to "filename.bin"
*/ */
......
...@@ -381,13 +381,14 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s ...@@ -381,13 +381,14 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s
} }
inline std::function<std::vector<double>(const void* data, int num_row, int num_col, int row_idx)> inline std::function<std::vector<double>(int row_idx)>
GetRowFunctionFromMat(int float_type, int is_row_major) { GetRowFunctionFromMat(const void* data, int num_row, int num_col, int float_type, int is_row_major) {
if (float_type == 0) { if (float_type == 0) {
const float* dptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [](const void* data, int, int num_col, int row_idx) { return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret; std::vector<double> ret;
const float* dptr = reinterpret_cast<const float*>(data);
dptr += num_col * row_idx; dptr += num_col * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + i))); ret.push_back(static_cast<double>(*(dptr + i)));
...@@ -395,9 +396,9 @@ GetRowFunctionFromMat(int float_type, int is_row_major) { ...@@ -395,9 +396,9 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
return ret; return ret;
}; };
} else { } else {
return [](const void* data, int num_row, int num_col, int row_idx) { return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret; std::vector<double> ret;
const float* dptr = reinterpret_cast<const float*>(data);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + num_row * i + row_idx))); ret.push_back(static_cast<double>(*(dptr + num_row * i + row_idx)));
} }
...@@ -405,10 +406,11 @@ GetRowFunctionFromMat(int float_type, int is_row_major) { ...@@ -405,10 +406,11 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
}; };
} }
} else { } else {
const double* dptr = reinterpret_cast<const double*>(data);
if (is_row_major) { if (is_row_major) {
return [](const void* data, int, int num_col, int row_idx) { return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret; std::vector<double> ret;
const double* dptr = reinterpret_cast<const double*>(data);
dptr += num_col * row_idx; dptr += num_col * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + i))); ret.push_back(static_cast<double>(*(dptr + i)));
...@@ -416,9 +418,9 @@ GetRowFunctionFromMat(int float_type, int is_row_major) { ...@@ -416,9 +418,9 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
return ret; return ret;
}; };
} else { } else {
return [](const void* data, int num_row, int num_col, int row_idx) { return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret; std::vector<double> ret;
const double* dptr = reinterpret_cast<const double*>(data);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + num_row * i + row_idx))); ret.push_back(static_cast<double>(*(dptr + num_row * i + row_idx)));
} }
...@@ -428,6 +430,39 @@ GetRowFunctionFromMat(int float_type, int is_row_major) { ...@@ -428,6 +430,39 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
} }
} }
inline std::function<std::vector<std::pair<int, double>>(int idx)>
GetRowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void* data, int float_type, uint64_t nindptr, uint64_t nelem) {
if (float_type == 0) {
const float* dptr = reinterpret_cast<const float*>(data);
return [&indptr, &indices, &dptr, &nindptr, &nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int32_t start = indptr[idx];
int32_t end = indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], dptr[i]);
}
return ret;
};
} else {
const double* dptr = reinterpret_cast<const double*>(data);
return [&indptr, &indices, &dptr, &nindptr, &nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int32_t start = indptr[idx];
int32_t end = indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], dptr[i]);
}
return ret;
};
}
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
...@@ -154,21 +154,21 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -154,21 +154,21 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
config.LoadFromString(parameters); config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
Dataset* ret = nullptr; Dataset* ret = nullptr;
auto get_row_fun = Common::GetRowFunctionFromMat(float_type, is_row_major); auto get_row_fun = Common::GetRowFunctionFromMat(data, nrow, ncol, float_type, is_row_major);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.io_config.data_random_seed); Random rand(config.io_config.data_random_seed);
const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_data(ncol); std::vector<std::vector<double>> sample_values(ncol);
for (size_t i = 0; i < sample_indices.size(); i++) { for (size_t i = 0; i < sample_indices.size(); i++) {
auto idx = sample_indices[i]; auto idx = sample_indices[i];
auto row = get_row_fun(data, nrow, ncol, static_cast<int>(idx)); auto row = get_row_fun(static_cast<int>(idx));
for (size_t j = 0; j < row.size(); j++) { for (size_t j = 0; j < row.size(); j++) {
sample_data[j].push_back(row[j]); sample_values[j].push_back(row[j]);
} }
} }
ret = loader.CostructFromSampleData(sample_data, nrow); ret = loader.CostructFromSampleData(sample_values, nrow);
} else { } else {
ret = new Dataset(); ret = new Dataset();
// need to set num_data first // need to set num_data first
...@@ -179,10 +179,121 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -179,10 +179,121 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
auto one_row = get_row_fun(data, nrow, ncol, i); auto one_row = get_row_fun(i);
ret->PushOneRow(tid, i, one_row); ret->PushOneRow(tid, i, one_row);
} }
ret->FinishLoad(); ret->FinishLoad();
*out = ret; *out = ret;
return 1; return 0;
}
/*!
 * \brief Create a dataset from a matrix in CSR format
 * \param indptr row pointers of the CSR matrix
 * \param indices column index of every stored element
 * \param data stored values, read as float or double depending on float_type
 * \param float_type 0 for float_32, otherwise data is read as float_64
 * \param nindptr length of indptr; the number of rows is taken as nindptr - 1
 * \param nelem number of stored elements
 * \param parameters additional parameters, parsed with OverallConfig::LoadFromString
 * \param reference when not nullptr, feature metadata is copied from this dataset
 *        instead of being constructed from sampled rows
 * \param out receives the created dataset
 * \return 0
 */
DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
                                        const int32_t* indices,
                                        const void* data,
                                        int float_type,
                                        uint64_t nindptr,
                                        uint64_t nelem,
                                        const char* parameters,
                                        const DatesetHandle* reference,
                                        DatesetHandle* out) {
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  Dataset* ret = nullptr;
  auto get_row_fun = Common::GetRowFunctionFromCSR(indptr, indices, data, float_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // sample_values[f][s] is the value of column f in the s-th sampled row
    std::vector<std::vector<double>> sample_values;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      // push 0 first, then overwrite with the values actually present in the row
      for (auto& feature_values : sample_values) {
        feature_values.push_back(0.0);
      }
      for (const std::pair<int, double>& inner_data : row) {
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          // a column beyond those seen so far: grow the feature set up to it
          size_t need_size = inner_data.first - sample_values.size() + 1;
          for (size_t j = 0; j < need_size; ++j) {
            // new columns start with i + 1 zeros so they line up with the existing ones
            sample_values.emplace_back(i + 1, 0.0);
          }
        }
        // edit the feature value
        sample_values[inner_data.first][i] = inner_data.second;
      }
    }
    ret = loader.CostructFromSampleData(sample_values, nrow);
  } else {
    ret = new Dataset();
    // need to set num_data first
    ret->SetNumData(nrow);
    reinterpret_cast<const Dataset*>(*reference)->CopyFeatureMetadataTo(ret, config.io_config.is_enable_sparse);
  }

  // Iterate over the signed row count rather than `nindptr - 1`, so the loop index
  // is never compared against an unsigned 64-bit bound.
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret;
  return 0;
}
/*!
 * \brief Delete the dataset a handle refers to
 * \param handle pointer to the handle of the dataset to delete
 * \return 0
 */
DllExport int LGBM_DatasetFree(DatesetHandle* handle) {
  delete reinterpret_cast<Dataset*>(*handle);
  return 0;
}
/*!
 * \brief Save a dataset through Dataset::SaveBinaryFile
 * \param handle handle of the dataset to save
 * \param filename name passed on to SaveBinaryFile
 * \return 0
 */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
                                     const char* filename) {
  Dataset* const target = reinterpret_cast<Dataset*>(handle);
  target->SaveBinaryFile(filename);
  return 0;
}
/*!
 * \brief Set a named field of a dataset; thin wrapper over Dataset::SetField
 * \param handle handle of the dataset to modify
 * \param field_name name of the field
 * \param field_data pointer to the field values
 * \param num_element number of values behind field_data
 * \param type type code of the values, forwarded unchanged
 * \return 0
 */
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
                                   const char* field_name,
                                   const void* field_data,
                                   uint64_t num_element,
                                   int type) {
  Dataset* const target = reinterpret_cast<Dataset*>(handle);
  target->SetField(field_name, field_data, num_element, type);
  return 0;
}
/*!
 * \brief Read a named field of a dataset; thin wrapper over Dataset::GetField
 * \param handle handle of the dataset to query
 * \param field_name name of the field
 * \param out_len receives the number of elements
 * \param out_ptr receives the pointer to the field values
 * \param out_type receives the type code of the values
 * \return 0
 */
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
                                   const char* field_name,
                                   uint64_t* out_len,
                                   const void** out_ptr,
                                   int* out_type) {
  Dataset* const source = reinterpret_cast<Dataset*>(handle);
  source->GetField(field_name, out_len, out_ptr, out_type);
  return 0;
}
/*!
 * \brief Report the value of Dataset::num_data() for a dataset
 * \param handle handle of the dataset to query
 * \param out receives the value returned by num_data()
 * \return 0
 */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
                                     uint64_t* out) {
  Dataset* const source = reinterpret_cast<Dataset*>(handle);
  *out = source->num_data();
  return 0;
}
/*!
 * \brief Report the value of Dataset::num_total_features() for a dataset
 * \param handle handle of the dataset to query
 * \param out receives the value returned by num_total_features()
 * \return 0
 */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
uint64_t* out) {
// the opaque handle is the Dataset pointer itself
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features();
return 0;
} }
...@@ -63,27 +63,52 @@ void Dataset::SetField(const char* field_name, const void* field_data, data_size ...@@ -63,27 +63,52 @@ void Dataset::SetField(const char* field_name, const void* field_data, data_size
if (type != 0) { if (type != 0) {
Log::Fatal("type of label should be float"); Log::Fatal("type of label should be float");
} }
metadata_.SetLabel(static_cast<const float*>(field_data), num_element); metadata_.SetLabel(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("weight") || name == std::string("weights")) { } else if (name == std::string("weight") || name == std::string("weights")) {
if (type != 0) { if (type != 0) {
Log::Fatal("type of weights should be float"); Log::Fatal("type of weights should be float");
} }
metadata_.SetWeights(static_cast<const float*>(field_data), num_element); metadata_.SetWeights(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("init_score")) { } else if (name == std::string("init_score")) {
if (type != 0) { if (type != 0) {
Log::Fatal("type of init_score should be float"); Log::Fatal("type of init_score should be float");
} }
metadata_.SetInitScore(static_cast<const float*>(field_data), num_element); metadata_.SetInitScore(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("query") || name == std::string("group")) { } else if (name == std::string("query") || name == std::string("group")) {
if (type != 1) { if (type != 1) {
Log::Fatal("type of init_score should be int"); Log::Fatal("type of init_score should be int");
} }
metadata_.SetQueryBoundaries(static_cast<const data_size_t*>(field_data), num_element); metadata_.SetQueryBoundaries(reinterpret_cast<const data_size_t*>(field_data), num_element);
} else { } else {
Log::Fatal("unknow field name: %s", field_name); Log::Fatal("unknow field name: %s", field_name);
} }
} }
/*!
 * \brief Look up a metadata field by name and expose its buffer to the caller
 * \param field_name one of label/target, weight/weights, init_score, query/group
 * \param out_len receives num_data_
 * \param out_ptr receives the pointer returned by the matching metadata accessor
 * \param out_type receives 0 for label, weights and init_score, 1 for query/group
 */
void Dataset::GetField(const char* field_name, uint64_t* out_len, const void** out_ptr, int* out_type) {
  std::string name(field_name);
  name = Common::Trim(name);

  const void* field_ptr = nullptr;
  int field_type = 0;
  if (name == "label" || name == "target") {
    field_ptr = metadata_.label();
  } else if (name == "weight" || name == "weights") {
    field_ptr = metadata_.weights();
  } else if (name == "init_score") {
    field_ptr = metadata_.init_score();
  } else if (name == "query" || name == "group") {
    // NOTE(review): the length reported below is num_data_ for this field as well;
    // confirm that it matches the size of the query boundary array
    field_ptr = metadata_.query_boundaries();
    field_type = 1;
  } else {
    Log::Fatal("unknow field name: %s", field_name);
    // outputs are left untouched for an unknown name
    return;
  }

  *out_ptr = field_ptr;
  *out_len = num_data_;
  *out_type = field_type;
}
void Dataset::SaveBinaryFile(const char* bin_filename) { void Dataset::SaveBinaryFile(const char* bin_filename) {
if (!is_loading_from_binfile_) { if (!is_loading_from_binfile_) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment