"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "7076cb8a3ac3a7b32dcf37be5593dddf27bf7f16"
Commit d41c78f9 authored by Guolin Ke's avatar Guolin Ke
Browse files

add load data from mat

parent f3b1daf8
...@@ -73,8 +73,8 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, ...@@ -73,8 +73,8 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
* \param out created dataset * \param out created dataset
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_CreateDatasetFromCSR(const uint32_t* indptr, DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
const uint32_t* indices, const int32_t* indices,
const void* data, const void* data,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
...@@ -98,8 +98,8 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint32_t* indptr, ...@@ -98,8 +98,8 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint32_t* indptr,
* \param out created dataset * \param out created dataset
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_CreateDatasetFromCSC(const uint32_t* col_ptr, DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
const uint32_t* indices, const int32_t* indices,
const void* data, const void* data,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
...@@ -124,10 +124,9 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint32_t* col_ptr, ...@@ -124,10 +124,9 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint32_t* col_ptr,
*/ */
DllExport int LGBM_CreateDatasetFromMat(const void* data, DllExport int LGBM_CreateDatasetFromMat(const void* data,
int float_type, int float_type,
uint64_t nrow, int32_t nrow,
uint64_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
double missing,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
...@@ -299,8 +298,8 @@ DllExport int LGBM_BoosterPredict(BoosterHandle handle, ...@@ -299,8 +298,8 @@ DllExport int LGBM_BoosterPredict(BoosterHandle handle,
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const uint32_t* indptr, const int32_t* indptr,
const uint32_t* indices, const int32_t* indices,
const void* data, const void* data,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
...@@ -328,8 +327,8 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -328,8 +327,8 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const uint32_t* col_ptr, const int32_t* col_ptr,
const uint32_t* indices, const int32_t* indices,
const void* data, const void* data,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
...@@ -357,9 +356,8 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -357,9 +356,8 @@ DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data, const void* data,
int float_type, int float_type,
uint64_t nrow, int32_t nrow,
uint64_t ncol, int32_t ncol,
double missing,
int predict_type, int predict_type,
uint64_t n_used_trees, uint64_t n_used_trees,
const double** out_result); const double** out_result);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
#include <LightGBM/config.h> #include <LightGBM/config.h>
#include <LightGBM/feature.h>
#include <vector> #include <vector>
#include <utility> #include <utility>
...@@ -16,8 +17,6 @@ ...@@ -16,8 +17,6 @@
namespace LightGBM { namespace LightGBM {
/*! \brief forward declaration */ /*! \brief forward declaration */
class Feature;
class BinMapper;
class DatasetLoader; class DatasetLoader;
/*! /*!
...@@ -250,6 +249,20 @@ public: ...@@ -250,6 +249,20 @@ public:
/*! \brief Destructor */ /*! \brief Destructor */
~Dataset(); ~Dataset();
/*!
* \brief Push one row of raw feature values into this dataset's features.
* \param tid thread id of the calling thread (forwarded to Feature::PushData;
*            presumably used for per-thread buffering -- TODO confirm in Feature)
* \param row_idx index of the row being pushed
* \param feature_values raw values for ALL total (original) columns; columns whose
*        used_feature_map_ entry is -1 (unused features) are skipped
*/
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
for (size_t i = 0; i < feature_values.size(); ++i) {
// map the original column index to the in-use feature index (-1 = not used)
int feature_idx = used_feature_map_[i];
if (feature_idx >= 0) {
features_[feature_idx]->PushData(tid, row_idx, feature_values[i]);
}
}
}
/*!
* \brief Set the number of data rows.
* Needed when the dataset is built row-by-row (via PushOneRow) rather than
* loaded from a file, since storage sizing depends on num_data_.
* \param num_data number of rows
*/
inline void SetNumData(data_size_t num_data) {
num_data_ = num_data;
}
/*! \brief Finalize all features after the last row has been pushed. */
void FinishLoad();
void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type); void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type);
......
...@@ -17,13 +17,14 @@ public: ...@@ -17,13 +17,14 @@ public:
Dataset* LoadFromFile(const char* filename, int rank, int num_machines); Dataset* LoadFromFile(const char* filename, int rank, int num_machines);
Dataset* LoadFromFile(const char* filename) { Dataset* LoadFromFile(const char* filename) {
LoadFromFile(filename, 0, 1); return LoadFromFile(filename, 0, 1);
} }
Dataset* LoadFromFileLikeOthers(const char* filename, const Dataset* other); Dataset* LoadFromFileLikeOthers(const char* filename, const Dataset* other);
Dataset* LoadFromBinFile(const char* bin_filename, int rank, int num_machines); Dataset* LoadFromBinFile(const char* bin_filename, int rank, int num_machines);
Dataset* CostructFromSampleData(std::vector<std::vector<double>>& sample_values, data_size_t num_data);
/*! \brief Disable copy */ /*! \brief Disable copy */
DatasetLoader& operator=(const DatasetLoader&) = delete; DatasetLoader& operator=(const DatasetLoader&) = delete;
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <cstdint> #include <cstdint>
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <functional>
namespace LightGBM { namespace LightGBM {
...@@ -380,6 +381,53 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s ...@@ -380,6 +381,53 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s
} }
/*!
* \brief Create an accessor that copies one row of a dense matrix into a std::vector<double>.
* \param float_type 0 if the matrix stores float values, any other value for double
* \param is_row_major non-zero for row-major storage, 0 for column-major
* \return callable taking (data, num_row, num_col, row_idx) and returning that row as doubles
*/
inline std::function<std::vector<double>(const void* data, int num_row, int num_col, int row_idx)>
GetRowFunctionFromMat(int float_type, int is_row_major) {
  if (float_type == 0) {
    if (is_row_major) {
      // row-major float: the row is contiguous, starting at data + num_col * row_idx
      return [](const void* data, int, int num_col, int row_idx) {
        std::vector<double> ret;
        ret.reserve(num_col);  // one allocation instead of repeated growth
        // size_t arithmetic: num_col * row_idx can overflow int for large matrices
        const float* dptr = reinterpret_cast<const float*>(data)
          + static_cast<size_t>(num_col) * static_cast<size_t>(row_idx);
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(static_cast<double>(dptr[i]));
        }
        return ret;
      };
    } else {
      // column-major float: element (row_idx, i) lives at data[num_row * i + row_idx]
      return [](const void* data, int num_row, int num_col, int row_idx) {
        std::vector<double> ret;
        ret.reserve(num_col);
        const float* dptr = reinterpret_cast<const float*>(data);
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(static_cast<double>(
            dptr[static_cast<size_t>(num_row) * static_cast<size_t>(i) + row_idx]));
        }
        return ret;
      };
    }
  } else {
    if (is_row_major) {
      // row-major double: contiguous row at data + num_col * row_idx
      return [](const void* data, int, int num_col, int row_idx) {
        std::vector<double> ret;
        ret.reserve(num_col);
        const double* dptr = reinterpret_cast<const double*>(data)
          + static_cast<size_t>(num_col) * static_cast<size_t>(row_idx);
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(dptr[i]);
        }
        return ret;
      };
    } else {
      // column-major double: element (row_idx, i) lives at data[num_row * i + row_idx]
      return [](const void* data, int num_row, int num_col, int row_idx) {
        std::vector<double> ret;
        ret.reserve(num_col);
        const double* dptr = reinterpret_cast<const double*>(data);
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(dptr[static_cast<size_t>(num_row) * static_cast<size_t>(i) + row_idx]);
        }
        return ret;
      };
    }
  }
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
#include <omp.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/c_api.h> #include <LightGBM/c_api.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/dataset.h> #include <LightGBM/dataset.h>
#include <LightGBM/boosting.h> #include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
...@@ -103,3 +108,81 @@ private: ...@@ -103,3 +108,81 @@ private:
using namespace LightGBM; using namespace LightGBM;
/*!
* \brief Return a human-readable message for the last error.
* NOTE(review): error tracking is not implemented yet -- this always returns the
* same placeholder string, regardless of whether any call actually failed.
*/
DllExport const char* LGBM_GetLastError() {
return "Not error msg now, will support soon";
}
/*!
* \brief Create a dataset by parsing a text file.
* \param filename path of the data file to load
* \param parameters configuration string, parsed into OverallConfig
* \param reference optional existing dataset; when given, its feature/bin layout
*        is reused (LoadFromFileLikeOthers) instead of being re-learned
* \param out receives the created dataset handle (caller owns it)
* \return always 0; failures are not reported yet -- NOTE(review): no error path
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
OverallConfig config;
config.LoadFromString(parameters);
// second argument is the predict function used for initial scores; not needed here
DatasetLoader loader(config.io_config, nullptr);
if (reference == nullptr) {
*out = loader.LoadFromFile(filename);
} else {
// align bin mappers / feature layout with the reference dataset
*out = loader.LoadFromFileLikeOthers(filename, reinterpret_cast<const Dataset*>(*reference));
}
return 0;
}
/*!
* \brief Load a dataset from a previously saved binary file.
* \param filename path of the binary dataset file
* \param out receives the created dataset handle (caller owns it)
* \return always 0; failures are not reported yet -- NOTE(review): no error path
* Note: no parameter string is accepted here, so a default-constructed
* OverallConfig is used; rank 0 / 1 machine means no distributed partitioning.
*/
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
DatesetHandle* out) {
OverallConfig config;
DatasetLoader loader(config.io_config, nullptr);
*out = loader.LoadFromBinFile(filename, 0, 1);
return 0;
}
/*!
* \brief Create a dataset from an in-memory dense matrix.
* \param data pointer to the matrix values (float or double, see float_type)
* \param float_type 0 for float data, any other value for double
* \param nrow number of rows
* \param ncol number of columns
* \param is_row_major non-zero if data is row-major, 0 for column-major
* \param parameters configuration string, parsed into OverallConfig
* \param reference optional existing dataset whose feature metadata (bin mappers)
*        is reused; when null, bins are learned from a sample of the rows
* \param out receives the created dataset handle (caller owns it)
* \return 0 on success (matching the header contract "0 when success, -1 when
*         failure"); the previous `return 1` here was inconsistent with every
*         other LGBM_* entry point
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int float_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  Dataset* ret = nullptr;
  // row accessor handling both element types and both storage orders
  auto get_row_fun = Common::GetRowFunctionFromMat(float_type, is_row_major);
  if (reference == nullptr) {
    // no reference: sample rows first so bin boundaries can be learned
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(
      nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // transpose sampled rows into per-column value lists for bin construction
    std::vector<std::vector<double>> sample_data(ncol);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(data, nrow, ncol, static_cast<int>(idx));
      for (size_t j = 0; j < row.size(); ++j) {
        sample_data[j].push_back(row[j]);
      }
    }
    ret = loader.CostructFromSampleData(sample_data, nrow);
  } else {
    ret = new Dataset();
    // num_data must be set before copying feature metadata, since feature
    // storage is sized from it
    ret->SetNumData(nrow);
    reinterpret_cast<const Dataset*>(*reference)->CopyFeatureMetadataTo(ret, config.io_config.is_enable_sparse);
  }
  // push every row; tid selects the calling thread's buffer inside PushOneRow
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(data, nrow, ncol, i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret;
  return 0;
}
#include <LightGBM/dataset.h> #include <LightGBM/dataset.h>
#include <LightGBM/feature.h> #include <LightGBM/feature.h>
#include <LightGBM/network.h>
#include <omp.h> #include <omp.h>
...@@ -29,6 +28,12 @@ Dataset::~Dataset() { ...@@ -29,6 +28,12 @@ Dataset::~Dataset() {
features_.clear(); features_.clear();
} }
// Finalize every feature after the last row has been pushed; delegates to each
// Feature's FinishLoad (implementation not visible here -- presumably flushes
// buffered pushed values into final bin storage; confirm in feature.h).
void Dataset::FinishLoad() {
#pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) {
features_[i]->FinishLoad();
}
}
void Dataset::CopyFeatureMetadataTo(Dataset *dataset, bool is_enable_sparse) const { void Dataset::CopyFeatureMetadataTo(Dataset *dataset, bool is_enable_sparse) const {
dataset->features_.clear(); dataset->features_.clear();
......
...@@ -303,7 +303,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int ...@@ -303,7 +303,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int
} }
mem_ptr += sizeof(int) * num_used_feature_map; mem_ptr += sizeof(int) * num_used_feature_map;
// get feature names // get feature names
feature_names_.clear(); dataset->feature_names_.clear();
// write feature names // write feature names
for (int i = 0; i < dataset->num_total_features_; ++i) { for (int i = 0; i < dataset->num_total_features_; ++i) {
int str_len = *(reinterpret_cast<const int*>(mem_ptr)); int str_len = *(reinterpret_cast<const int*>(mem_ptr));
...@@ -314,7 +314,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int ...@@ -314,7 +314,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int
mem_ptr += sizeof(char); mem_ptr += sizeof(char);
str_buf << tmp_char; str_buf << tmp_char;
} }
feature_names_.emplace_back(str_buf.str()); dataset->feature_names_.emplace_back(str_buf.str());
} }
// read size of meta data // read size of meta data
...@@ -406,6 +406,37 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int ...@@ -406,6 +406,37 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int
return dataset; return dataset;
} }
/*!
* \brief Construct a Dataset from per-column sampled values: find bin boundaries
*        for every column, then create a Feature for each non-trivial column.
* \param sample_values one vector of sampled raw values per original column
* \param num_data total number of rows the dataset will hold
* \return newly allocated Dataset (caller owns it)
* Ownership note: each non-trivial BinMapper is handed to the Feature it creates;
* trivial ones (single bin) are deleted here.
*/
Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& sample_values, data_size_t num_data) {
  std::vector<BinMapper*> bin_mappers(sample_values.size());
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
    bin_mappers[i] = new BinMapper();
    bin_mappers[i]->FindBin(&sample_values[i], io_config_.max_bin);
  }
  Dataset* dataset = new Dataset();
  dataset->features_.clear();
  dataset->num_data_ = num_data;
  // -1 means doesn't use this feature
  dataset->used_feature_map_ = std::vector<int>(bin_mappers.size(), -1);
  dataset->num_total_features_ = static_cast<int>(bin_mappers.size());
  for (size_t i = 0; i < bin_mappers.size(); ++i) {
    if (!bin_mappers[i]->is_trival()) {
      // map real feature index to used feature index
      dataset->used_feature_map_[i] = static_cast<int>(dataset->features_.size());
      // push new feature
      dataset->features_.push_back(new Feature(static_cast<int>(i), bin_mappers[i],
        dataset->num_data_, io_config_.is_enable_sparse));
    } else {
      // if feature is trival(only 1 bin), free spaces
      // cast required: %d expects int, but i is size_t (passing size_t to %d is
      // undefined behavior on LP64 platforms)
      Log::Warning("Ignoring Column_%d , only has one value", static_cast<int>(i));
      delete bin_mappers[i];
    }
  }
  // num_features_ must reflect the created features: Dataset::FinishLoad (and
  // other per-feature loops) iterate up to num_features_, which would otherwise
  // stay at its default and skip every feature for datasets built this way
  dataset->num_features_ = static_cast<int>(dataset->features_.size());
  return dataset;
}
// ---- private functions ---- // ---- private functions ----
...@@ -738,11 +769,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -738,11 +769,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
dataset->metadata_.SetInitScore(init_score, dataset->num_data_ * dataset->num_class_); dataset->metadata_.SetInitScore(init_score, dataset->num_data_ * dataset->num_class_);
delete[] init_score; delete[] init_score;
} }
dataset->FinishLoad();
#pragma omp parallel for schedule(guided)
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->features_[i]->FinishLoad();
}
// text data can be free after loaded feature values // text data can be free after loaded feature values
text_data.clear(); text_data.clear();
} }
...@@ -803,11 +830,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -803,11 +830,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
dataset->metadata_.SetInitScore(init_score, dataset->num_data_ * dataset->num_class_); dataset->metadata_.SetInitScore(init_score, dataset->num_data_ * dataset->num_class_);
delete[] init_score; delete[] init_score;
} }
dataset->FinishLoad();
#pragma omp parallel for schedule(guided)
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->features_[i]->FinishLoad();
}
} }
/*! \brief Check can load from binary file */ /*! \brief Check can load from binary file */
......
...@@ -103,50 +103,50 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -103,50 +103,50 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
} }
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_data_) { if (weights_ != nullptr && num_weights_ != num_data_) {
Log::Fatal("Weights size doesn't match data size");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
Log::Fatal("Weights size doesn't match data size");
} }
// check query boundries // check query boundries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_data_) {
Log::Fatal("Query size doesn't match data size");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
Log::Fatal("Query size doesn't match data size");
} }
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_data_) { if (init_score_ != nullptr && num_init_score_ != num_data_) {
delete[] init_score_; delete[] init_score_;
Log::Fatal("Initial score size doesn't match data size");
init_score_ = nullptr; init_score_ = nullptr;
num_init_score_ = 0; num_init_score_ = 0;
Log::Fatal("Initial score size doesn't match data size");
} }
} else { } else {
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size()); data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
// check weights // check weights
if (weights_ != nullptr && num_weights_ != num_all_data) { if (weights_ != nullptr && num_weights_ != num_all_data) {
Log::Fatal("Weights size doesn't match data size");
delete[] weights_; delete[] weights_;
num_weights_ = 0; num_weights_ = 0;
weights_ = nullptr; weights_ = nullptr;
Log::Fatal("Weights size doesn't match data size");
} }
// check query boundries // check query boundries
if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) { if (query_boundaries_ != nullptr && query_boundaries_[num_queries_] != num_all_data) {
Log::Fatal("Query size doesn't match data size");
delete[] query_boundaries_; delete[] query_boundaries_;
num_queries_ = 0; num_queries_ = 0;
query_boundaries_ = nullptr; query_boundaries_ = nullptr;
Log::Fatal("Query size doesn't match data size");
} }
// contain initial score file // contain initial score file
if (init_score_ != nullptr && num_init_score_ != num_all_data) { if (init_score_ != nullptr && num_init_score_ != num_all_data) {
Log::Fatal("Initial score size doesn't match data size");
delete[] init_score_; delete[] init_score_;
num_init_score_ = 0; num_init_score_ = 0;
init_score_ = nullptr; init_score_ = nullptr;
Log::Fatal("Initial score size doesn't match data size");
} }
// get local weights // get local weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment