Commit d1ac71e1 authored by Guolin Ke's avatar Guolin Ke
Browse files

add some interfaces to the dataset

parent b953cd58
...@@ -142,7 +142,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, ...@@ -142,7 +142,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
DllExport int LGBM_DatasetSetField(DatesetHandle handle, DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const char* field_name, const char* field_name,
const void* field_data, const void* field_data,
uint64_t field_len, uint64_t num_element,
int type); int type);
/*! /*!
......
...@@ -17,6 +17,7 @@ namespace LightGBM { ...@@ -17,6 +17,7 @@ namespace LightGBM {
/*! \brief forward declaration */ /*! \brief forward declaration */
class Feature; class Feature;
class BinMapper;
/*! /*!
* \brief This class is used to store some meta(non-feature) data for training data, * \brief This class is used to store some meta(non-feature) data for training data,
...@@ -79,6 +80,13 @@ public: ...@@ -79,6 +80,13 @@ public:
void CheckOrPartition(data_size_t num_all_data, void CheckOrPartition(data_size_t num_all_data,
const std::vector<data_size_t>& used_data_indices); const std::vector<data_size_t>& used_data_indices);
void SetLabel(const float* label, data_size_t len);
void SetWeights(const float* weights, data_size_t len);
void SetQueryBoundaries(const data_size_t* QueryBoundaries, data_size_t len);
/*! /*!
* \brief Set initial scores * \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score. * \param init_score Initial scores, this class will manage memory for init_score.
...@@ -188,8 +196,6 @@ private: ...@@ -188,8 +196,6 @@ private:
data_size_t num_weights_; data_size_t num_weights_;
/*! \brief Label data */ /*! \brief Label data */
float* label_; float* label_;
/*! \brief Label data, int type */
int16_t* label_int_;
/*! \brief Weights data */ /*! \brief Weights data */
float* weights_; float* weights_;
/*! \brief Query boundaries */ /*! \brief Query boundaries */
...@@ -272,6 +278,14 @@ public: ...@@ -272,6 +278,14 @@ public:
/*! \brief Destructor */ /*! \brief Destructor */
~Dataset(); ~Dataset();
/*! \brief Init Dataset with specific binmapper */
void InitByBinMapper(std::vector<const BinMapper*> bin_mappers, data_size_t num_data);
/*! \brief push raw data into dataset */
void PushData(const std::vector<std::vector<std::pair<int, float>>>& datas, data_size_t start_idx, bool is_finished);
void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type);
/*! /*!
* \brief Load training data on parallel training * \brief Load training data on parallel training
* \param rank Rank of local machine * \param rank Rank of local machine
...@@ -311,6 +325,8 @@ public: ...@@ -311,6 +325,8 @@ public:
*/ */
void SaveBinaryFile(const char* bin_filename); void SaveBinaryFile(const char* bin_filename);
std::vector<const BinMapper*> GetBinMappers() const;
/*! /*!
* \brief Get a feature pointer for specific index * \brief Get a feature pointer for specific index
* \param i Index for feature * \param i Index for feature
......
...@@ -285,6 +285,83 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti ...@@ -285,6 +285,83 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti
} }
} }
/*!
* \brief Initialize this Dataset from externally constructed bin mappers.
* \param bin_mappers One entry per original column; nullptr means the column is unused
* \param num_data Number of rows the dataset will hold
*/
void Dataset::InitByBinMapper(std::vector<const BinMapper*> bin_mappers, data_size_t num_data) {
  num_data_ = num_data;
  global_num_data_ = num_data_;
  // initialize metadata with labels only (no weights / query boundaries yet)
  metadata_.Init(num_data_, -1, -1);
  // free previously owned features before re-initializing
  for (auto& feature : features_) {
    delete feature;
  }
  features_.clear();
  used_feature_map_ = std::vector<int>(bin_mappers.size(), -1);
  for (size_t i = 0; i < bin_mappers.size(); ++i) {
    if (bin_mappers[i] != nullptr) {
      // Record the inner feature index BEFORE pushing: the new feature's
      // index equals features_.size() at this point. (The original stored
      // size() after push_back, which is off by one and would make
      // used_feature_map_ point one past the actual feature.)
      used_feature_map_[i] = static_cast<int>(features_.size());
      // Dereference to copy-construct the BinMapper; passing the raw pointer
      // would not invoke the copy constructor.
      features_.push_back(new Feature(static_cast<int>(i), new BinMapper(*bin_mappers[i]), num_data_, is_enable_sparse_));
    }
  }
  num_features_ = static_cast<int>(features_.size());
}
std::vector<const BinMapper*> Dataset::GetBinMappers() const {
std::vector<const BinMapper*> ret(num_total_features_, nullptr);
for (const auto feature : features_) {
ret[feature->feature_index()] = feature->bin_mapper();
}
return ret;
}
/*!
* \brief Push a batch of sparse rows into the dataset.
* \param datas Batch of rows; each row is a list of (column index, value) pairs
* \param start_idx Row offset of the first row in this batch
* \param is_finished True on the last batch; triggers finalization of all features
*/
void Dataset::PushData(const std::vector<std::vector<std::pair<int, float>>>& datas, data_size_t start_idx, bool is_finished) {
  // rows are independent, so the batch can be pushed in parallel
  #pragma omp parallel for schedule(guided)
  for (data_size_t i = 0; i < static_cast<int>(datas.size()); ++i) {
    const int tid = omp_get_thread_num();
    for (const auto& inner_data : datas[i]) {
      const int feature_idx = used_feature_map_[inner_data.first];
      if (feature_idx >= 0) {
        // only used (mapped) columns are stored; unused columns map to -1
        features_[feature_idx]->PushData(tid, start_idx + i, inner_data.second);
      }
    }
  }
  if (is_finished) {
    // all batches pushed: finalize each feature's bin data
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_features_; ++i) {
      features_[i]->FinishLoad();
    }
  }
}
/*!
* \brief Set one metadata field from a raw buffer (C API entry point helper).
* \param field_name One of: label/target, weight/weights, init_score, query/group
* \param field_data Raw buffer whose element type is given by `type`
* \param num_element Number of elements in field_data
* \param type 0 = float, 1 = int (data_size_t)
*/
void Dataset::SetField(const char* field_name, const void* field_data, data_size_t num_element, int type) {
  std::string name(field_name);
  name = Common::Trim(name);
  if (name == std::string("label") || name == std::string("target")) {
    if (type != 0) {
      Log::Fatal("type of label should be float");
    }
    metadata_.SetLabel(static_cast<const float*>(field_data), num_element);
  } else if (name == std::string("weight") || name == std::string("weights")) {
    if (type != 0) {
      Log::Fatal("type of weights should be float");
    }
    metadata_.SetWeights(static_cast<const float*>(field_data), num_element);
  } else if (name == std::string("init_score")) {
    if (type != 0) {
      Log::Fatal("type of init_score should be float");
    }
    metadata_.SetInitScore(static_cast<const float*>(field_data), num_element);
  } else if (name == std::string("query") || name == std::string("group")) {
    if (type != 1) {
      // fixed: this message previously said "init_score" (copy/paste error)
      Log::Fatal("type of query should be int");
    }
    metadata_.SetQueryBoundaries(static_cast<const data_size_t*>(field_data), num_element);
  } else {
    // fixed typo: "unknow" -> "unknown"
    Log::Fatal("unknown field name: %s", field_name);
  }
}
void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<std::string>& sample_data) { void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<std::string>& sample_data) {
// sample_values[i][j], means the value of j-th sample on i-th feature // sample_values[i][j], means the value of j-th sample on i-th feature
std::vector<std::vector<float>> sample_values; std::vector<std::vector<float>> sample_values;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace LightGBM { namespace LightGBM {
Metadata::Metadata() Metadata::Metadata()
:label_(nullptr), label_int_(nullptr), weights_(nullptr), :label_(nullptr), weights_(nullptr),
query_boundaries_(nullptr), query_boundaries_(nullptr),
query_weights_(nullptr), init_score_(nullptr), queries_(nullptr){ query_weights_(nullptr), init_score_(nullptr), queries_(nullptr){
...@@ -225,6 +225,48 @@ void Metadata::SetInitScore(const float* init_score, data_size_t len) { ...@@ -225,6 +225,48 @@ void Metadata::SetInitScore(const float* init_score, data_size_t len) {
} }
} }
void Metadata::SetLabel(const float* label, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("len of label is not same with #data");
}
if (label_ != nullptr) { delete[] label_; }
label_ = new float[num_data_];
for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = label[i];
}
}
void Metadata::SetWeights(const float* weights, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("len of weights is not same with #data");
}
if (weights_ != nullptr) { delete[] weights_; }
num_weights_ = num_data_;
weights_ = new float[num_weights_];
for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = weights[i];
}
LoadQueryWeights();
}
/*!
* \brief Set query (group) information from per-query document counts.
* \param query_boundaries Per-query document COUNTS (len == number of queries);
*        the sum check below confirms this is the expected input format
* \param len Number of queries
*/
void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len) {
  data_size_t sum = 0;
  for (data_size_t i = 0; i < len; ++i) {
    sum += query_boundaries[i];
  }
  if (num_data_ != sum) {
    Log::Fatal("sum of query counts is not same with #data");
  }
  if (query_boundaries_ != nullptr) { delete[] query_boundaries_; }
  num_queries_ = len;
  // Fixed: the original stored the raw counts directly into query_boundaries_,
  // but the member (by its name and the error message above) holds cumulative
  // row offsets: query i spans rows [query_boundaries_[i], query_boundaries_[i+1]),
  // requiring num_queries_ + 1 entries. Convert counts to a prefix sum.
  // NOTE(review): verify that consumers (e.g. LoadQueryWeights) expect
  // cumulative boundaries rather than counts.
  query_boundaries_ = new data_size_t[num_queries_ + 1];
  query_boundaries_[0] = 0;
  for (data_size_t i = 0; i < num_queries_; ++i) {
    query_boundaries_[i + 1] = query_boundaries_[i] + query_boundaries[i];
  }
  LoadQueryWeights();
}
void Metadata::LoadWeights() { void Metadata::LoadWeights() {
num_weights_ = 0; num_weights_ = 0;
std::string weight_filename(data_filename_); std::string weight_filename(data_filename_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment