Commit 494fef34 authored by Guolin Ke's avatar Guolin Ke
Browse files

finish the expose interface of dataset

parent d41c78f9
......@@ -79,32 +79,6 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief create a dataset from CSC format
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param float_type 0 for float_32 1 for float_64
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_row,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......
......@@ -258,6 +258,15 @@ public:
}
}
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) {
for (auto& inner_data : feature_values) {
int feature_idx = used_feature_map_[inner_data.first];
if (feature_idx >= 0) {
features_[feature_idx]->PushData(tid, row_idx, inner_data.second);
}
}
}
inline void SetNumData(data_size_t num_data) {
num_data_ = num_data;
}
......@@ -266,6 +275,8 @@ public:
void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type);
void GetField(const char* field_name, uint64_t* out_len, const void** out_ptr, int* out_type);
/*!
* \brief Save current dataset into binary file, will save to "filename.bin"
*/
......
......@@ -381,13 +381,14 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s
}
inline std::function<std::vector<double>(const void* data, int num_row, int num_col, int row_idx)>
GetRowFunctionFromMat(int float_type, int is_row_major) {
inline std::function<std::vector<double>(int row_idx)>
GetRowFunctionFromMat(const void* data, int num_row, int num_col, int float_type, int is_row_major) {
if (float_type == 0) {
const float* dptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [](const void* data, int, int num_col, int row_idx) {
return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
const float* dptr = reinterpret_cast<const float*>(data);
dptr += num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + i)));
......@@ -395,9 +396,9 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
return ret;
};
} else {
return [](const void* data, int num_row, int num_col, int row_idx) {
return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
const float* dptr = reinterpret_cast<const float*>(data);
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + num_row * i + row_idx)));
}
......@@ -405,10 +406,11 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
};
}
} else {
const double* dptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [](const void* data, int, int num_col, int row_idx) {
return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
const double* dptr = reinterpret_cast<const double*>(data);
dptr += num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + i)));
......@@ -416,9 +418,9 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
return ret;
};
} else {
return [](const void* data, int num_row, int num_col, int row_idx) {
return [&dptr, &num_col, &num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
const double* dptr = reinterpret_cast<const double*>(data);
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(dptr + num_row * i + row_idx)));
}
......@@ -428,6 +430,39 @@ GetRowFunctionFromMat(int float_type, int is_row_major) {
}
}
inline std::function<std::vector<std::pair<int, double>>(int idx)>
GetRowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void* data, int float_type, uint64_t nindptr, uint64_t nelem) {
if (float_type == 0) {
const float* dptr = reinterpret_cast<const float*>(data);
return [&indptr, &indices, &dptr, &nindptr, &nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int32_t start = indptr[idx];
int32_t end = indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], dptr[i]);
}
return ret;
};
} else {
const double* dptr = reinterpret_cast<const double*>(data);
return [&indptr, &indices, &dptr, &nindptr, &nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int32_t start = indptr[idx];
int32_t end = indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], dptr[i]);
}
return ret;
};
}
}
} // namespace Common
} // namespace LightGBM
......
......@@ -154,21 +154,21 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
Dataset* ret = nullptr;
auto get_row_fun = Common::GetRowFunctionFromMat(float_type, is_row_major);
auto get_row_fun = Common::GetRowFunctionFromMat(data, nrow, ncol, float_type, is_row_major);
if (reference == nullptr) {
// sample data first
Random rand(config.io_config.data_random_seed);
const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_data(ncol);
std::vector<std::vector<double>> sample_values(ncol);
for (size_t i = 0; i < sample_indices.size(); i++) {
auto idx = sample_indices[i];
auto row = get_row_fun(data, nrow, ncol, static_cast<int>(idx));
auto row = get_row_fun(static_cast<int>(idx));
for (size_t j = 0; j < row.size(); j++) {
sample_data[j].push_back(row[j]);
sample_values[j].push_back(row[j]);
}
}
ret = loader.CostructFromSampleData(sample_data, nrow);
ret = loader.CostructFromSampleData(sample_values, nrow);
} else {
ret = new Dataset();
// need to set num_data first
......@@ -179,10 +179,121 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
#pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) {
const int tid = omp_get_thread_num();
auto one_row = get_row_fun(data, nrow, ncol, i);
auto one_row = get_row_fun(i);
ret->PushOneRow(tid, i, one_row);
}
ret->FinishLoad();
*out = ret;
return 1;
return 0;
}
DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
OverallConfig config;
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
Dataset* ret = nullptr;
auto get_row_fun = Common::GetRowFunctionFromCSR(indptr, indices, data, float_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (reference == nullptr) {
// sample data first
Random rand(config.io_config.data_random_seed);
const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values;
for (size_t i = 0; i < sample_indices.size(); ++i) {
auto idx = sample_indices[i];
auto row = get_row_fun(static_cast<int>(idx));
// push 0 first, then edit the value according existing feature values
for (auto& feature_values : sample_values) {
feature_values.push_back(0.0);
}
for (std::pair<int, double>& inner_data : row) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
// if need expand feature set
size_t need_size = inner_data.first - sample_values.size() + 1;
for (size_t j = 0; j < need_size; ++j) {
// push i+1 0
sample_values.emplace_back(i + 1, 0.0f);
}
}
// edit the feature value
sample_values[inner_data.first][i] = inner_data.second;
}
}
ret = loader.CostructFromSampleData(sample_values, nrow);
} else {
ret = new Dataset();
// need to set num_data first
ret->SetNumData(nrow);
reinterpret_cast<const Dataset*>(*reference)->CopyFeatureMetadataTo(ret, config.io_config.is_enable_sparse);
}
#pragma omp parallel for schedule(guided)
for (int i = 0; i < nindptr - 1; ++i) {
const int tid = omp_get_thread_num();
auto one_row = get_row_fun(i);
ret->PushOneRow(tid, i, one_row);
}
ret->FinishLoad();
*out = ret;
return 0;
}
DllExport int LGBM_DatasetFree(DatesetHandle* handle) {
auto dataset = reinterpret_cast<Dataset*>(*handle);
delete dataset;
return 0;
}
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
const char* filename) {
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SaveBinaryFile(filename);
return 0;
}
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const char* field_name,
const void* field_data,
uint64_t num_element,
int type) {
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SetField(field_name, field_data, num_element, type);
return 0;
}
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
const char* field_name,
uint64_t* out_len,
const void** out_ptr,
int* out_type) {
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->GetField(field_name, out_len, out_ptr, out_type);
return 0;
}
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
uint64_t* out) {
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data();
return 0;
}
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
uint64_t* out) {
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features();
return 0;
}
......@@ -63,27 +63,52 @@ void Dataset::SetField(const char* field_name, const void* field_data, data_size
if (type != 0) {
Log::Fatal("type of label should be float");
}
metadata_.SetLabel(static_cast<const float*>(field_data), num_element);
metadata_.SetLabel(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("weight") || name == std::string("weights")) {
if (type != 0) {
Log::Fatal("type of weights should be float");
}
metadata_.SetWeights(static_cast<const float*>(field_data), num_element);
metadata_.SetWeights(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("init_score")) {
if (type != 0) {
Log::Fatal("type of init_score should be float");
}
metadata_.SetInitScore(static_cast<const float*>(field_data), num_element);
metadata_.SetInitScore(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("query") || name == std::string("group")) {
if (type != 1) {
Log::Fatal("type of init_score should be int");
}
metadata_.SetQueryBoundaries(static_cast<const data_size_t*>(field_data), num_element);
metadata_.SetQueryBoundaries(reinterpret_cast<const data_size_t*>(field_data), num_element);
} else {
Log::Fatal("unknow field name: %s", field_name);
}
}
void Dataset::GetField(const char* field_name, uint64_t* out_len, const void** out_ptr, int* out_type) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
*out_ptr = metadata_.label();
*out_len = num_data_;
*out_type = 0;
} else if (name == std::string("weight") || name == std::string("weights")) {
*out_ptr = metadata_.weights();
*out_len = num_data_;
*out_type = 0;
} else if (name == std::string("init_score")) {
*out_ptr = metadata_.init_score();
*out_len = num_data_;
*out_type = 0;
} else if (name == std::string("query") || name == std::string("group")) {
*out_ptr = metadata_.query_boundaries();
*out_len = num_data_;
*out_type = 1;
} else {
Log::Fatal("unknow field name: %s", field_name);
}
}
void Dataset::SaveBinaryFile(const char* bin_filename) {
if (!is_loading_from_binfile_) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment