"googlemock/include/gmock/vscode:/vscode.git/clone" did not exist on "e935e6c387cdf541f73b2cbbbe02e651c12887a9"
Commit 5442ed78 authored by Guolin Ke's avatar Guolin Ke Committed by xuehui
Browse files

Refactor for RAII (#86)

* RAII for utils, application and c_api (partial)

* raii for class in include folder

* raii for application and boosting

* raii for dataset and dataset loader

* raii for dense bin and parser

* RAII refactor for almost all classes

* RAII for c_api

* clean code

* refine repeated code

* Decouple the "sigmoid" between objective and boosting.

* change std::vector<bool> back to std::vector<char> due to concurrency problem

* slight reduce some memory cost
parent 3586673a
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <LightGBM/config.h> #include <LightGBM/config.h>
#include <vector> #include <vector>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -60,20 +61,18 @@ private: ...@@ -60,20 +61,18 @@ private:
/*! \brief All configs */ /*! \brief All configs */
OverallConfig config_; OverallConfig config_;
/*! \brief Dataset loader */
DatasetLoader* dataset_loader_;
/*! \brief Training data */ /*! \brief Training data */
Dataset* train_data_; std::unique_ptr<Dataset> train_data_;
/*! \brief Validation data */ /*! \brief Validation data */
std::vector<Dataset*> valid_datas_; std::vector<std::unique_ptr<Dataset>> valid_datas_;
/*! \brief Metric for training data */ /*! \brief Metric for training data */
std::vector<Metric*> train_metric_; std::vector<std::unique_ptr<Metric>> train_metric_;
/*! \brief Metrics for validation data */ /*! \brief Metrics for validation data */
std::vector<std::vector<Metric*>> valid_metrics_; std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
/*! \brief Boosting object */ /*! \brief Boosting object */
Boosting* boosting_; std::unique_ptr<Boosting> boosting_;
/*! \brief Training objective function */ /*! \brief Training objective function */
ObjectiveFunction* objective_fun_; std::unique_ptr<ObjectiveFunction> objective_fun_;
}; };
......
...@@ -111,7 +111,7 @@ private: ...@@ -111,7 +111,7 @@ private:
/*! \brief Number of bins */ /*! \brief Number of bins */
int num_bin_; int num_bin_;
/*! \brief Store upper bound for each bin */ /*! \brief Store upper bound for each bin */
double* bin_upper_bound_; std::vector<double> bin_upper_bound_;
/*! \brief True if this feature is trival */ /*! \brief True if this feature is trival */
bool is_trival_; bool is_trival_;
/*! \brief Sparse rate of this bins( num_bin0/num_data ) */ /*! \brief Sparse rate of this bins( num_bin0/num_data ) */
...@@ -133,11 +133,11 @@ public: ...@@ -133,11 +133,11 @@ public:
/*! /*!
* \brief Initialization logic. * \brief Initialization logic.
* \param used_indices If used_indices==nullptr means using all data, otherwise, used_indices[i] != 0 means i-th data is used * \param used_indices If used_indices.size() == 0 means using all data, otherwise, used_indices[i] == true means i-th data is used
(this logic was build for bagging logic) (this logic was build for bagging logic)
* \param num_leaves Number of leaves on this iteration * \param num_leaves Number of leaves on this iteration
*/ */
virtual void Init(const char* used_indices, data_size_t num_leaves) = 0; virtual void Init(const char* used_idices, data_size_t num_leaves) = 0;
/*! /*!
* \brief Construct histogram by using this bin * \brief Construct histogram by using this bin
...@@ -155,7 +155,7 @@ public: ...@@ -155,7 +155,7 @@ public:
* \brief Split current bin, and perform re-order by leaf * \brief Split current bin, and perform re-order by leaf
* \param leaf Using which leaf's to split * \param leaf Using which leaf's to split
* \param right_leaf The new leaf index after perform this split * \param right_leaf The new leaf index after perform this split
* \param left_indices left_indices[i] != 0 means the i-th data will be on left leaf after split * \param left_indices left_indices[i] == true means the i-th data will be on left leaf after split
*/ */
virtual void Split(int leaf, int right_leaf, const char* left_indices) = 0; virtual void Split(int leaf, int right_leaf, const char* left_indices) = 0;
}; };
...@@ -231,7 +231,7 @@ public: ...@@ -231,7 +231,7 @@ public:
* \param out Output Result * \param out Output Result
*/ */
virtual void ConstructHistogram( virtual void ConstructHistogram(
data_size_t* data_indices, data_size_t num_data, const data_size_t* data_indices, data_size_t num_data,
const score_t* ordered_gradients, const score_t* ordered_hessians, const score_t* ordered_gradients, const score_t* ordered_hessians,
HistogramBinEntry* out) const = 0; HistogramBinEntry* out) const = 0;
......
...@@ -52,6 +52,7 @@ public: ...@@ -52,6 +52,7 @@ public:
*/ */
virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0; virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;
virtual bool EvalAndCheckEarlyStopping() = 0;
/*! /*!
* \brief Get evaluation result at data_idx data * \brief Get evaluation result at data_idx data
* \param data_idx 0: training data, 1: 1st validation data * \param data_idx 0: training data, 1: 1st validation data
...@@ -98,6 +99,9 @@ public: ...@@ -98,6 +99,9 @@ public:
/*! /*!
* \brief save model to file * \brief save model to file
* \param num_used_model number of model that want to save, -1 means save all
* \param is_finish is training finished or not
* \param filename filename that want to save to
*/ */
virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0; virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0;
...@@ -141,6 +145,12 @@ public: ...@@ -141,6 +145,12 @@ public:
*/ */
virtual const char* Name() const = 0; virtual const char* Name() const = 0;
Boosting() = default;
/*! \brief Disable copy */
Boosting& operator=(const Boosting&) = delete;
/*! \brief Disable copy */
Boosting(const Boosting&) = delete;
/*! /*!
* \brief Create boosting object * \brief Create boosting object
* \param type Type of boosting * \param type Type of boosting
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <algorithm> #include <algorithm>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -182,6 +183,7 @@ enum TreeLearnerType { ...@@ -182,6 +183,7 @@ enum TreeLearnerType {
struct BoostingConfig: public ConfigBase { struct BoostingConfig: public ConfigBase {
public: public:
virtual ~BoostingConfig() {} virtual ~BoostingConfig() {}
double sigmoid = 1.0f;
int output_freq = 1; int output_freq = 1;
bool is_provide_training_metric = false; bool is_provide_training_metric = false;
int num_iterations = 10; int num_iterations = 10;
...@@ -193,19 +195,12 @@ public: ...@@ -193,19 +195,12 @@ public:
int num_class = 1; int num_class = 1;
double drop_rate = 0.01; double drop_rate = 0.01;
int drop_seed = 4; int drop_seed = 4;
void Set(const std::unordered_map<std::string, std::string>& params) override;
};
/*! \brief Config for GBDT */
struct GBDTConfig: public BoostingConfig {
public:
TreeLearnerType tree_learner_type = TreeLearnerType::kSerialTreeLearner; TreeLearnerType tree_learner_type = TreeLearnerType::kSerialTreeLearner;
TreeConfig tree_config; TreeConfig tree_config;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
private: private:
void GetTreeLearnerType(const std::unordered_map<std::string, void GetTreeLearnerType(const std::unordered_map<std::string,
std::string>& params); std::string>& params);
}; };
/*! \brief Config for Network */ /*! \brief Config for Network */
...@@ -229,16 +224,12 @@ public: ...@@ -229,16 +224,12 @@ public:
bool is_parallel_find_bin = false; bool is_parallel_find_bin = false;
IOConfig io_config; IOConfig io_config;
BoostingType boosting_type = BoostingType::kGBDT; BoostingType boosting_type = BoostingType::kGBDT;
BoostingConfig* boosting_config = nullptr; BoostingConfig boosting_config;
std::string objective_type = "regression"; std::string objective_type = "regression";
ObjectiveConfig objective_config; ObjectiveConfig objective_config;
std::vector<std::string> metric_types; std::vector<std::string> metric_types;
MetricConfig metric_config; MetricConfig metric_config;
~OverallConfig() {
if (boosting_config != nullptr) {
delete boosting_config;
}
}
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
void LoadFromString(const char* str); void LoadFromString(const char* str);
private: private:
......
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
* \brief Get pointer of label * \brief Get pointer of label
* \return Pointer of label * \return Pointer of label
*/ */
inline const float* label() const { return label_; } inline const float* label() const { return label_.data(); }
/*! /*!
* \brief Set label for one record * \brief Set label for one record
...@@ -142,7 +142,7 @@ public: ...@@ -142,7 +142,7 @@ public:
* \return Pointer of weights * \return Pointer of weights
*/ */
inline const float* weights() inline const float* weights()
const { return weights_; } const { return weights_.data(); }
/*! /*!
* \brief Get data boundaries on queries, if not exists, will return nullptr * \brief Get data boundaries on queries, if not exists, will return nullptr
...@@ -152,7 +152,7 @@ public: ...@@ -152,7 +152,7 @@ public:
* \return Pointer of data boundaries on queries * \return Pointer of data boundaries on queries
*/ */
inline const data_size_t* query_boundaries() inline const data_size_t* query_boundaries()
const { return query_boundaries_; } const { return query_boundaries_.data(); }
/*! /*!
* \brief Get Number of queries * \brief Get Number of queries
...@@ -164,13 +164,18 @@ public: ...@@ -164,13 +164,18 @@ public:
* \brief Get weights for queries, if not exists, will return nullptr * \brief Get weights for queries, if not exists, will return nullptr
* \return Pointer of weights for queries * \return Pointer of weights for queries
*/ */
inline const float* query_weights() const { return query_weights_; } inline const float* query_weights() const { return query_weights_.data(); }
/*! /*!
* \brief Get initial scores, if not exists, will return nullptr * \brief Get initial scores, if not exists, will return nullptr
* \return Pointer of initial scores * \return Pointer of initial scores
*/ */
inline const float* init_score() const { return init_score_; } inline const float* init_score() const { return init_score_.data(); }
/*! \brief Disable copy */
Metadata& operator=(const Metadata&) = delete;
/*! \brief Disable copy */
Metadata(const Metadata&) = delete;
private: private:
/*! \brief Load initial scores from file */ /*! \brief Load initial scores from file */
...@@ -190,21 +195,21 @@ private: ...@@ -190,21 +195,21 @@ private:
/*! \brief Number of weights, used to check correct weight file */ /*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_; data_size_t num_weights_;
/*! \brief Label data */ /*! \brief Label data */
float* label_; std::vector<float> label_;
/*! \brief Weights data */ /*! \brief Weights data */
float* weights_; std::vector<float> weights_;
/*! \brief Query boundaries */ /*! \brief Query boundaries */
data_size_t* query_boundaries_; std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */ /*! \brief Query weights */
float* query_weights_; std::vector<float> query_weights_;
/*! \brief Number of querys */ /*! \brief Number of querys */
data_size_t num_queries_; data_size_t num_queries_;
/*! \brief Number of Initial score, used to check correct weight file */ /*! \brief Number of Initial score, used to check correct weight file */
data_size_t num_init_score_; data_size_t num_init_score_;
/*! \brief Initial score */ /*! \brief Initial score */
float* init_score_; std::vector<float> init_score_;
/*! \brief Queries data */ /*! \brief Queries data */
data_size_t* queries_; std::vector<data_size_t> queries_;
}; };
...@@ -292,8 +297,6 @@ public: ...@@ -292,8 +297,6 @@ public:
*/ */
void SaveBinaryFile(const char* bin_filename); void SaveBinaryFile(const char* bin_filename);
std::vector<const BinMapper*> GetBinMappers() const;
void CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_sparse); void CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_sparse);
/*! /*!
...@@ -301,7 +304,7 @@ public: ...@@ -301,7 +304,7 @@ public:
* \param i Index for feature * \param i Index for feature
* \return Pointer of feature * \return Pointer of feature
*/ */
inline const Feature* FeatureAt(int i) const { return features_[i]; } inline const Feature* FeatureAt(int i) const { return features_[i].get(); }
/*! /*!
* \brief Get meta data pointer * \brief Get meta data pointer
...@@ -332,7 +335,7 @@ public: ...@@ -332,7 +335,7 @@ public:
private: private:
const char* data_filename_; const char* data_filename_;
/*! \brief Store used features */ /*! \brief Store used features */
std::vector<Feature*> features_; std::vector<std::unique_ptr<Feature>> features_;
/*! \brief Mapper from real feature index to used index*/ /*! \brief Mapper from real feature index to used index*/
std::vector<int> used_feature_map_; std::vector<int> used_feature_map_;
/*! \brief Number of used features*/ /*! \brief Number of used features*/
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include <cstdio> #include <cstdio>
#include <memory>
#include <vector> #include <vector>
namespace LightGBM { namespace LightGBM {
...@@ -26,8 +26,8 @@ public: ...@@ -26,8 +26,8 @@ public:
data_size_t num_data, bool is_enable_sparse) data_size_t num_data, bool is_enable_sparse)
:bin_mapper_(bin_mapper) { :bin_mapper_(bin_mapper) {
feature_index_ = feature_idx; feature_index_ = feature_idx;
bin_data_ = Bin::CreateBin(num_data, bin_mapper_->num_bin(), bin_data_.reset(Bin::CreateBin(num_data, bin_mapper_->num_bin(),
bin_mapper_->sparse_rate(), is_enable_sparse, &is_sparse_, bin_mapper_->ValueToBin(0)); bin_mapper_->sparse_rate(), is_enable_sparse, &is_sparse_, bin_mapper_->ValueToBin(0)));
} }
/*! /*!
* \brief Constructor from memory * \brief Constructor from memory
...@@ -45,24 +45,22 @@ public: ...@@ -45,24 +45,22 @@ public:
is_sparse_ = *(reinterpret_cast<const bool*>(memory_ptr)); is_sparse_ = *(reinterpret_cast<const bool*>(memory_ptr));
memory_ptr += sizeof(is_sparse_); memory_ptr += sizeof(is_sparse_);
// get bin mapper // get bin mapper
bin_mapper_ = new BinMapper(memory_ptr); bin_mapper_.reset(new BinMapper(memory_ptr));
memory_ptr += bin_mapper_->SizesInByte(); memory_ptr += bin_mapper_->SizesInByte();
data_size_t num_data = num_all_data; data_size_t num_data = num_all_data;
if (local_used_indices.size() > 0) { if (local_used_indices.size() > 0) {
num_data = static_cast<data_size_t>(local_used_indices.size()); num_data = static_cast<data_size_t>(local_used_indices.size());
} }
if (is_sparse_) { if (is_sparse_) {
bin_data_ = Bin::CreateSparseBin(num_data, bin_mapper_->num_bin(), bin_mapper_->ValueToBin(0)); bin_data_.reset(Bin::CreateSparseBin(num_data, bin_mapper_->num_bin(), bin_mapper_->ValueToBin(0)));
} else { } else {
bin_data_ = Bin::CreateDenseBin(num_data, bin_mapper_->num_bin(), bin_mapper_->ValueToBin(0)); bin_data_.reset(Bin::CreateDenseBin(num_data, bin_mapper_->num_bin(), bin_mapper_->ValueToBin(0)));
} }
// get bin data // get bin data
bin_data_->LoadFromMemory(memory_ptr, local_used_indices); bin_data_->LoadFromMemory(memory_ptr, local_used_indices);
} }
/*! \brief Destructor */ /*! \brief Destructor */
~Feature() { ~Feature() {
delete bin_mapper_;
delete bin_data_;
} }
/*! /*!
...@@ -79,11 +77,11 @@ public: ...@@ -79,11 +77,11 @@ public:
/*! \brief Index of this feature */ /*! \brief Index of this feature */
inline int feature_index() const { return feature_index_; } inline int feature_index() const { return feature_index_; }
/*! \brief Bin mapper that this feature used */ /*! \brief Bin mapper that this feature used */
inline const BinMapper* bin_mapper() const { return bin_mapper_; } inline const BinMapper* bin_mapper() const { return bin_mapper_.get(); }
/*! \brief Number of bin of this feature */ /*! \brief Number of bin of this feature */
inline int num_bin() const { return bin_mapper_->num_bin(); } inline int num_bin() const { return bin_mapper_->num_bin(); }
/*! \brief Get bin data of this feature */ /*! \brief Get bin data of this feature */
inline const Bin* bin_data() const { return bin_data_; } inline const Bin* bin_data() const { return bin_data_.get(); }
/*! /*!
* \brief From bin to feature value * \brief From bin to feature value
* \param bin * \param bin
...@@ -118,9 +116,9 @@ private: ...@@ -118,9 +116,9 @@ private:
/*! \brief Index of this feature */ /*! \brief Index of this feature */
int feature_index_; int feature_index_;
/*! \brief Bin mapper that this feature used */ /*! \brief Bin mapper that this feature used */
BinMapper* bin_mapper_; std::unique_ptr<BinMapper> bin_mapper_;
/*! \brief Bin data of this feature */ /*! \brief Bin data of this feature */
Bin* bin_data_; std::unique_ptr<Bin> bin_data_;
/*! \brief True if this feature is sparse */ /*! \brief True if this feature is sparse */
bool is_sparse_; bool is_sparse_;
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <limits> #include <limits>
#include <vector> #include <vector>
#include <functional> #include <functional>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -18,11 +19,6 @@ const score_t kMinScore = -std::numeric_limits<score_t>::infinity(); ...@@ -18,11 +19,6 @@ const score_t kMinScore = -std::numeric_limits<score_t>::infinity();
const score_t kEpsilon = 1e-15f; const score_t kEpsilon = 1e-15f;
template<typename T>
std::vector<const T*> ConstPtrInVectorWarpper(std::vector<T*> input) {
return std::vector<const T*>(input.begin(), input.end());
}
using ReduceFunction = std::function<void(const char*, char*, int)>; using ReduceFunction = std::function<void(const char*, char*, int)>;
using PredictFunction = using PredictFunction =
......
...@@ -27,7 +27,7 @@ public: ...@@ -27,7 +27,7 @@ public:
virtual void Init(const char* test_name, virtual void Init(const char* test_name,
const Metadata& metadata, data_size_t num_data) = 0; const Metadata& metadata, data_size_t num_data) = 0;
virtual std::vector<std::string> GetName() const = 0; virtual const std::vector<std::string>& GetName() const = 0;
virtual score_t factor_to_bigger_better() const = 0; virtual score_t factor_to_bigger_better() const = 0;
/*! /*!
...@@ -36,6 +36,12 @@ public: ...@@ -36,6 +36,12 @@ public:
*/ */
virtual std::vector<double> Eval(const score_t* score) const = 0; virtual std::vector<double> Eval(const score_t* score) const = 0;
Metric() = default;
/*! \brief Disable copy */
Metric& operator=(const Metric&) = delete;
/*! \brief Disable copy */
Metric(const Metric&) = delete;
/*! /*!
* \brief Create object of metrics * \brief Create object of metrics
* \param type Specific type of metric * \param type Specific type of metric
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <functional> #include <functional>
#include <vector> #include <vector>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -139,8 +140,8 @@ public: ...@@ -139,8 +140,8 @@ public:
* \param block_len The block size for different machines * \param block_len The block size for different machines
* \param output Output result * \param output Output result
*/ */
static void Allgather(char* input, int all_size, int* block_start, static void Allgather(char* input, int all_size, const int* block_start,
int* block_len, char* output); const int* block_len, char* output);
/*! /*!
* \brief Perform reduce scatter by using recursive halving algorithm. * \brief Perform reduce scatter by using recursive halving algorithm.
...@@ -153,7 +154,7 @@ public: ...@@ -153,7 +154,7 @@ public:
* \param reducer Reduce function * \param reducer Reduce function
*/ */
static void ReduceScatter(char* input, int input_size, static void ReduceScatter(char* input, int input_size,
int* block_start, int* block_len, char* output, const int* block_start, const int* block_len, char* output,
const ReduceFunction& reducer); const ReduceFunction& reducer);
private: private:
...@@ -162,17 +163,17 @@ private: ...@@ -162,17 +163,17 @@ private:
/*! \brief Rank of local machine */ /*! \brief Rank of local machine */
static int rank_; static int rank_;
/*! \brief The network interface, provide send/recv functions */ /*! \brief The network interface, provide send/recv functions */
static Linkers *linkers_; static std::unique_ptr<Linkers> linkers_;
/*! \brief Bruck map for all gather algorithm*/ /*! \brief Bruck map for all gather algorithm*/
static BruckMap bruck_map_; static BruckMap bruck_map_;
/*! \brief Recursive halving map for reduce scatter */ /*! \brief Recursive halving map for reduce scatter */
static RecursiveHalvingMap recursive_halving_map_; static RecursiveHalvingMap recursive_halving_map_;
/*! \brief Buffer to store block start index */ /*! \brief Buffer to store block start index */
static int* block_start_; static std::vector<int> block_start_;
/*! \brief Buffer to store block size */ /*! \brief Buffer to store block size */
static int* block_len_; static std::vector<int> block_len_;
/*! \brief Buffer */ /*! \brief Buffer */
static char* buffer_; static std::vector<char> buffer_;
/*! \brief Size of buffer_ */ /*! \brief Size of buffer_ */
static int buffer_size_; static int buffer_size_;
}; };
......
...@@ -31,12 +31,13 @@ public: ...@@ -31,12 +31,13 @@ public:
virtual void GetGradients(const score_t* score, virtual void GetGradients(const score_t* score,
score_t* gradients, score_t* hessians) const = 0; score_t* gradients, score_t* hessians) const = 0;
/*! virtual const char* GetName() const = 0;
* \brief Get sigmoid param for this objective if has.
* This function is used for prediction task, if has sigmoid param, the prediction value will be transform by sigmoid function. ObjectiveFunction() = default;
* \return Sigmoid param, if <=0.0 means don't use sigmoid transform on this objective. /*! \brief Disable copy */
*/ ObjectiveFunction& operator=(const ObjectiveFunction&) = delete;
virtual score_t GetSigmoid() const = 0; /*! \brief Disable copy */
ObjectiveFunction(const ObjectiveFunction&) = delete;
/*! /*!
* \brief Create object of objective function * \brief Create object of objective function
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -111,7 +112,7 @@ private: ...@@ -111,7 +112,7 @@ private:
* \param data_idx Index of record * \param data_idx Index of record
* \return Leaf index * \return Leaf index
*/ */
inline int GetLeaf(const std::vector<BinIterator*>& iterators, inline int GetLeaf(const std::vector<std::unique_ptr<BinIterator>>& iterators,
data_size_t data_idx) const; data_size_t data_idx) const;
/*! /*!
...@@ -127,28 +128,28 @@ private: ...@@ -127,28 +128,28 @@ private:
int num_leaves_; int num_leaves_;
// following values used for non-leaf node // following values used for non-leaf node
/*! \brief A non-leaf node's left child */ /*! \brief A non-leaf node's left child */
int* left_child_; std::vector<int> left_child_;
/*! \brief A non-leaf node's right child */ /*! \brief A non-leaf node's right child */
int* right_child_; std::vector<int> right_child_;
/*! \brief A non-leaf node's split feature */ /*! \brief A non-leaf node's split feature */
int* split_feature_; std::vector<int> split_feature_;
/*! \brief A non-leaf node's split feature, the original index */ /*! \brief A non-leaf node's split feature, the original index */
int* split_feature_real_; std::vector<int> split_feature_real_;
/*! \brief A non-leaf node's split threshold in bin */ /*! \brief A non-leaf node's split threshold in bin */
unsigned int* threshold_in_bin_; std::vector<unsigned int> threshold_in_bin_;
/*! \brief A non-leaf node's split threshold in feature value */ /*! \brief A non-leaf node's split threshold in feature value */
double* threshold_; std::vector<double> threshold_;
/*! \brief A non-leaf node's split gain */ /*! \brief A non-leaf node's split gain */
double* split_gain_; std::vector<double> split_gain_;
// used for leaf node // used for leaf node
/*! \brief The parent of leaf */ /*! \brief The parent of leaf */
int* leaf_parent_; std::vector<int> leaf_parent_;
/*! \brief Output of leaves */ /*! \brief Output of leaves */
double* leaf_value_; std::vector<double> leaf_value_;
/*! \brief Output of internal nodes(save internal output for per inference feature importance calc) */ /*! \brief Output of internal nodes(save internal output for per inference feature importance calc) */
double* internal_value_; std::vector<double> internal_value_;
/*! \brief Depth for leaves */ /*! \brief Depth for leaves */
int* leaf_depth_; std::vector<int> leaf_depth_;
}; };
...@@ -162,12 +163,11 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const { ...@@ -162,12 +163,11 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
return leaf; return leaf;
} }
inline int Tree::GetLeaf(const std::vector<BinIterator*>& iterators, inline int Tree::GetLeaf(const std::vector<std::unique_ptr<BinIterator>>& iterators,
data_size_t data_idx) const { data_size_t data_idx) const {
int node = 0; int node = 0;
while (node >= 0) { while (node >= 0) {
if (iterators[split_feature_[node]]->Get(data_idx) <= if (iterators[split_feature_[node]]->Get(data_idx) <= threshold_in_bin_[node]) {
threshold_in_bin_[node]) {
node = left_child_[node]; node = left_child_[node];
} else { } else {
node = right_child_[node]; node = right_child_[node];
......
...@@ -49,6 +49,12 @@ public: ...@@ -49,6 +49,12 @@ public:
*/ */
virtual void AddPredictionToScore(score_t *out_score) const = 0; virtual void AddPredictionToScore(score_t *out_score) const = 0;
TreeLearner() = default;
/*! \brief Disable copy */
TreeLearner& operator=(const TreeLearner&) = delete;
/*! \brief Disable copy */
TreeLearner(const TreeLearner&) = delete;
/*! /*!
* \brief Create object of tree learner * \brief Create object of tree learner
* \param type Type of tree learner * \param type Type of tree learner
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <functional> #include <functional>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -359,29 +360,15 @@ inline void Softmax(std::vector<double>* p_rec) { ...@@ -359,29 +360,15 @@ inline void Softmax(std::vector<double>* p_rec) {
} }
} }
template<typename T1, typename T2> template<typename T>
inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) { std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
std::vector<std::pair<T1, T2>> arr; std::vector<const T*> ret;
for (size_t i = start; i < keys.size(); ++i) { for (size_t i = 0; i < input.size(); ++i) {
arr.emplace_back(keys[i], values[i]); ret.push_back(input.at(i).get());
}
if (!is_reverse) {
std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
return a.first < b.first;
});
} else {
std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
return a.first > b.first;
});
}
for (size_t i = start; i < arr.size(); ++i) {
keys[i] = arr[i].first;
values[i] = arr[i].second;
} }
return ret;
} }
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
#ifndef LIGHTGBM_UTILS_LRU_POOL_H_
#define LIGHTGBM_UTILS_LRU_POOL_H_
#include <LightGBM/utils/log.h>
#include <algorithm>
#include <functional>
#include <vector>
namespace LightGBM {
/*!
* \brief A LRU cached object pool, used for store historical histograms.
*
* The pool holds up to cache_size slots for total_size logical indices.
* When the cache is smaller than the total, a miss evicts the
* least-recently-used slot. All storage is owned by std::vector (RAII),
* which also fixes a double free in the raw-pointer version: FreeAll()
* never reset the pointers to nullptr, so a second ResetSize() call that
* ended with is_enough_ == true left the mapper arrays dangling for the
* destructor to delete again.
*/
template<typename T>
class LRUPool {
public:
  /*!
  * \brief Reset pool size
  * \param cache_size Max cache size
  * \param total_size Total size will be used
  */
  void ResetSize(int cache_size, int total_size) {
    cache_size_ = cache_size;
    // at least need 2 bucket to store smaller leaf and larger leaf
    CHECK(cache_size_ >= 2);
    total_size_ = total_size;
    if (cache_size_ > total_size_) {
      cache_size_ = total_size_;
    }
    is_enough_ = (cache_size_ == total_size_);
    // replace the pool wholesale; old storage is released automatically
    pool_ = std::vector<T>(cache_size_);
    if (!is_enough_) {
      mapper_.resize(total_size_);
      inverse_mapper_.resize(cache_size_);
      last_used_time_.resize(cache_size_);
      ResetMap();
    }
  }
  /*!
  * \brief Reset mapper
  */
  void ResetMap() {
    if (!is_enough_) {
      cur_time_ = 0;
      std::fill(mapper_.begin(), mapper_.end(), -1);
      std::fill(inverse_mapper_.begin(), inverse_mapper_.end(), -1);
      std::fill(last_used_time_.begin(), last_used_time_.end(), 0);
    }
  }
  /*!
  * \brief Fill the pool
  * \param obj_create_fun that used to generate object
  */
  void Fill(std::function<T()> obj_create_fun) {
    for (int i = 0; i < cache_size_; ++i) {
      pool_[i] = obj_create_fun();
    }
  }
  /*!
  * \brief Get data for the specific index
  * \param idx which index want to get
  * \param out output data will store into this
  * \return True if this index is in the pool, False if this index is not in the pool
  */
  bool Get(int idx, T* out) {
    if (is_enough_) {
      // every logical index has its own slot
      *out = pool_[idx];
      return true;
    } else if (mapper_[idx] >= 0) {
      // cache hit: refresh the slot's timestamp
      int slot = mapper_[idx];
      *out = pool_[slot];
      last_used_time_[slot] = ++cur_time_;
      return true;
    } else {
      // cache miss: choose the least recently used slot
      int slot = static_cast<int>(
          std::min_element(last_used_time_.begin(), last_used_time_.end())
          - last_used_time_.begin());
      *out = pool_[slot];
      last_used_time_[slot] = ++cur_time_;
      // reset previous mapper
      if (inverse_mapper_[slot] >= 0) mapper_[inverse_mapper_[slot]] = -1;
      // update current mapper
      mapper_[idx] = slot;
      inverse_mapper_[slot] = idx;
      return false;
    }
  }
  /*!
  * \brief Move data from one index to another index
  * \param src_idx
  * \param dst_idx
  */
  void Move(int src_idx, int dst_idx) {
    if (is_enough_) {
      std::swap(pool_[src_idx], pool_[dst_idx]);
      return;
    }
    if (mapper_[src_idx] < 0) {
      // source is not cached, nothing to move
      return;
    }
    // get slot of src idx
    int slot = mapper_[src_idx];
    // reset src_idx
    mapper_[src_idx] = -1;
    // move to dst idx
    mapper_[dst_idx] = slot;
    last_used_time_[slot] = ++cur_time_;
    inverse_mapper_[slot] = dst_idx;
  }
private:
  /*! \brief Object slots */
  std::vector<T> pool_;
  /*! \brief Size of the cache (number of slots) */
  int cache_size_ = 0;
  /*! \brief Number of logical indices served by the pool */
  int total_size_ = 0;
  /*! \brief True if every logical index has its own slot (no eviction needed) */
  bool is_enough_ = false;
  /*! \brief Maps logical index -> slot, -1 if not currently cached */
  std::vector<int> mapper_;
  /*! \brief Maps slot -> logical index, -1 if the slot is free */
  std::vector<int> inverse_mapper_;
  /*! \brief Last-used timestamp per slot, used for LRU eviction */
  std::vector<int> last_used_time_;
  /*! \brief Logical clock for LRU timestamps */
  int cur_time_ = 0;
};
}
#endif   // LIGHTGBM_UTILS_LRU_POOL_H_
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <thread> #include <thread>
#include <memory>
namespace LightGBM{ namespace LightGBM{
...@@ -35,34 +36,32 @@ public: ...@@ -35,34 +36,32 @@ public:
size_t cnt = 0; size_t cnt = 0;
const size_t buffer_size = 16 * 1024 * 1024 ; const size_t buffer_size = 16 * 1024 * 1024 ;
// buffer used for the process_fun // buffer used for the process_fun
char* buffer_process = new char[buffer_size]; auto buffer_process = std::vector<char>(buffer_size);
// buffer used for the file reading // buffer used for the file reading
char* buffer_read = new char[buffer_size]; auto buffer_read = std::vector<char>(buffer_size);
size_t read_cnt = 0; size_t read_cnt = 0;
if (skip_bytes > 0) { if (skip_bytes > 0) {
// skip first k bytes // skip first k bytes
read_cnt = fread(buffer_process, 1, skip_bytes, file); read_cnt = fread(buffer_process.data(), 1, skip_bytes, file);
} }
// read first block // read first block
read_cnt = fread(buffer_process, 1, buffer_size, file); read_cnt = fread(buffer_process.data(), 1, buffer_size, file);
size_t last_read_cnt = 0; size_t last_read_cnt = 0;
while (read_cnt > 0) { while (read_cnt > 0) {
    // start read thread // start read thread
std::thread read_worker = std::thread( std::thread read_worker = std::thread(
[file, buffer_read, buffer_size, &last_read_cnt] { [file, &buffer_read, buffer_size, &last_read_cnt] {
last_read_cnt = fread(buffer_read, 1, buffer_size, file); last_read_cnt = fread(buffer_read.data(), 1, buffer_size, file);
} }
); );
// start process // start process
cnt += process_fun(buffer_process, read_cnt); cnt += process_fun(buffer_process.data(), read_cnt);
// wait for read thread // wait for read thread
read_worker.join(); read_worker.join();
// exchange the buffer // exchange the buffer
std::swap(buffer_process, buffer_read); std::swap(buffer_process, buffer_read);
read_cnt = last_read_cnt; read_cnt = last_read_cnt;
} }
delete[] buffer_process;
delete[] buffer_read;
// close file // close file
fclose(file); fclose(file);
return cnt; return cnt;
......
...@@ -26,8 +26,7 @@ ...@@ -26,8 +26,7 @@
namespace LightGBM { namespace LightGBM {
Application::Application(int argc, char** argv) Application::Application(int argc, char** argv) {
:dataset_loader_(nullptr), train_data_(nullptr), boosting_(nullptr), objective_fun_(nullptr) {
LoadParameters(argc, argv); LoadParameters(argc, argv);
// set number of threads for openmp // set number of threads for openmp
if (config_.num_threads > 0) { if (config_.num_threads > 0) {
...@@ -39,23 +38,6 @@ Application::Application(int argc, char** argv) ...@@ -39,23 +38,6 @@ Application::Application(int argc, char** argv)
} }
Application::~Application() { Application::~Application() {
if (dataset_loader_ != nullptr) { delete dataset_loader_; }
if (train_data_ != nullptr) { delete train_data_; }
for (auto& data : valid_datas_) {
if (data != nullptr) { delete data; }
}
valid_datas_.clear();
for (auto& metric : train_metric_) {
if (metric != nullptr) { delete metric; }
}
for (auto& metric : valid_metrics_) {
for (auto& sub_metric : metric) {
if (sub_metric != nullptr) { delete sub_metric; }
}
}
valid_metrics_.clear();
if (boosting_ != nullptr) { delete boosting_; }
if (objective_fun_ != nullptr) { delete objective_fun_; }
if (config_.is_parallel) { if (config_.is_parallel) {
Network::Dispose(); Network::Dispose();
} }
...@@ -125,11 +107,10 @@ void Application::LoadData() { ...@@ -125,11 +107,10 @@ void Application::LoadData() {
auto start_time = std::chrono::high_resolution_clock::now(); auto start_time = std::chrono::high_resolution_clock::now();
// prediction is needed if using input initial model(continued train) // prediction is needed if using input initial model(continued train)
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
Predictor* predictor = nullptr;
// need to continue training // need to continue training
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, true, false); Predictor predictor(boosting_.get(), true, false);
predict_fun = predictor->GetPredictFunction(); predict_fun = predictor.GetPredictFunction();
} }
// sync up random seed for data partition // sync up random seed for data partition
...@@ -138,37 +119,41 @@ void Application::LoadData() { ...@@ -138,37 +119,41 @@ void Application::LoadData() {
GlobalSyncUpByMin<int>(config_.io_config.data_random_seed); GlobalSyncUpByMin<int>(config_.io_config.data_random_seed);
} }
dataset_loader_ = new DatasetLoader(config_.io_config, predict_fun); DatasetLoader dataset_loader(config_.io_config, predict_fun);
dataset_loader_->SetHeader(config_.io_config.data_filename.c_str()); dataset_loader.SetHeader(config_.io_config.data_filename.c_str());
// load Training data // load Training data
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
// load data for parallel training // load data for parallel training
train_data_ = dataset_loader_->LoadFromFile(config_.io_config.data_filename.c_str(), train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
Network::rank(), Network::num_machines()); Network::rank(), Network::num_machines()));
} else { } else {
// load data for single machine // load data for single machine
train_data_ = dataset_loader_->LoadFromFile(config_.io_config.data_filename.c_str(), 0, 1); train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), 0, 1));
} }
// need save binary file // need save binary file
if (config_.io_config.is_save_binary_file) { if (config_.io_config.is_save_binary_file) {
train_data_->SaveBinaryFile(nullptr); train_data_->SaveBinaryFile(nullptr);
} }
// create training metric // create training metric
if (config_.boosting_config->is_provide_training_metric) { if (config_.boosting_config.is_provide_training_metric) {
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
Metric* metric = auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
Metric::CreateMetric(metric_type, config_.metric_config);
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init("training", train_data_->metadata(), metric->Init("training", train_data_->metadata(),
train_data_->num_data()); train_data_->num_data());
train_metric_.push_back(metric); train_metric_.push_back(std::move(metric));
} }
} }
train_metric_.shrink_to_fit();
// Add validation data, if it exists // Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) { for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
// add // add
valid_datas_.push_back(dataset_loader_->LoadFromFileAlignWithOtherDataset(config_.io_config.valid_data_filenames[i].c_str(), auto new_dataset = std::unique_ptr<Dataset>(
train_data_)); dataset_loader.LoadFromFileAlignWithOtherDataset(
config_.io_config.valid_data_filenames[i].c_str(),
train_data_.get())
);
valid_datas_.push_back(std::move(new_dataset));
// need save binary file // need save binary file
if (config_.io_config.is_save_binary_file) { if (config_.io_config.is_save_binary_file) {
valid_datas_.back()->SaveBinaryFile(nullptr); valid_datas_.back()->SaveBinaryFile(nullptr);
...@@ -177,17 +162,17 @@ void Application::LoadData() { ...@@ -177,17 +162,17 @@ void Application::LoadData() {
// add metric for validation data // add metric for validation data
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
Metric* metric = Metric::CreateMetric(metric_type, config_.metric_config); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init(config_.io_config.valid_data_filenames[i].c_str(), metric->Init(config_.io_config.valid_data_filenames[i].c_str(),
valid_datas_.back()->metadata(), valid_datas_.back()->metadata(),
valid_datas_.back()->num_data()); valid_datas_.back()->num_data());
valid_metrics_.back().push_back(metric); valid_metrics_.back().push_back(std::move(metric));
} }
valid_metrics_.back().shrink_to_fit();
} }
if (predictor != nullptr) { valid_datas_.shrink_to_fit();
delete predictor; valid_metrics_.shrink_to_fit();
}
auto end_time = std::chrono::high_resolution_clock::now(); auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration // output used time on each iteration
Log::Info("Finished loading data in %f seconds", Log::Info("Finished loading data in %f seconds",
...@@ -201,40 +186,38 @@ void Application::InitTrain() { ...@@ -201,40 +186,38 @@ void Application::InitTrain() {
Log::Info("Finished initializing network"); Log::Info("Finished initializing network");
    // sync global random seed for feature partition // sync global random seed for feature partition
if (config_.boosting_type == BoostingType::kGBDT || config_.boosting_type == BoostingType::kDART) { if (config_.boosting_type == BoostingType::kGBDT || config_.boosting_type == BoostingType::kDART) {
GBDTConfig* gbdt_config = config_.boosting_config.tree_config.feature_fraction_seed =
dynamic_cast<GBDTConfig*>(config_.boosting_config); GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed);
gbdt_config->tree_config.feature_fraction_seed = config_.boosting_config.tree_config.feature_fraction =
GlobalSyncUpByMin<int>(gbdt_config->tree_config.feature_fraction_seed); GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction);
gbdt_config->tree_config.feature_fraction =
GlobalSyncUpByMin<double>(gbdt_config->tree_config.feature_fraction);
} }
} }
// create boosting // create boosting
boosting_ = boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type, Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str()); config_.io_config.input_model.c_str()));
// create objective function // create objective function
objective_fun_ = objective_fun_.reset(
ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config); config_.objective_config));
// load training data // load training data
LoadData(); LoadData();
// initialize the objective function // initialize the objective function
objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
// initialize the boosting // initialize the boosting
boosting_->Init(config_.boosting_config, train_data_, objective_fun_, boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
ConstPtrInVectorWarpper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting // add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) { for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i], boosting_->AddDataset(valid_datas_[i].get(),
ConstPtrInVectorWarpper<Metric>(valid_metrics_[i])); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
} }
Log::Info("Finished initializing training"); Log::Info("Finished initializing training");
} }
void Application::Train() { void Application::Train() {
Log::Info("Started training..."); Log::Info("Started training...");
int total_iter = config_.boosting_config->num_iterations; int total_iter = config_.boosting_config.num_iterations;
bool is_finished = false; bool is_finished = false;
bool need_eval = true; bool need_eval = true;
auto start_time = std::chrono::high_resolution_clock::now(); auto start_time = std::chrono::high_resolution_clock::now();
...@@ -256,7 +239,7 @@ void Application::Train() { ...@@ -256,7 +239,7 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict); boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_predict_raw_score, Predictor predictor(boosting_.get(), config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index); config_.io_config.is_predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
...@@ -264,8 +247,8 @@ void Application::Predict() { ...@@ -264,8 +247,8 @@ void Application::Predict() {
} }
void Application::InitPredict() { void Application::InitPredict() {
boosting_ = boosting_.reset(
Boosting::CreateBoosting(config_.io_config.input_model.c_str()); Boosting::CreateBoosting(config_.io_config.input_model.c_str()));
Log::Info("Finished initializing prediction"); Log::Info("Finished initializing prediction");
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <string> #include <string>
#include <memory>
namespace LightGBM { namespace LightGBM {
...@@ -36,16 +37,15 @@ public: ...@@ -36,16 +37,15 @@ public:
{ {
num_threads_ = omp_get_num_threads(); num_threads_ = omp_get_num_threads();
} }
features_ = new double*[num_threads_];
for (int i = 0; i < num_threads_; ++i) { for (int i = 0; i < num_threads_; ++i) {
features_[i] = new double[num_features_]; features_.push_back(std::vector<double>(num_features_));
} }
features_.shrink_to_fit();
if (is_predict_leaf_index) { if (is_predict_leaf_index) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result for leaf index // get result for leaf index
auto result = boosting_->PredictLeafIndex(features_[tid]); auto result = boosting_->PredictLeafIndex(features_[tid].data());
return std::vector<double>(result.begin(), result.end()); return std::vector<double>(result.begin(), result.end());
}; };
} else { } else {
...@@ -53,12 +53,12 @@ public: ...@@ -53,12 +53,12 @@ public:
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid]); return boosting_->PredictRaw(features_[tid].data());
}; };
} else { } else {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
return boosting_->Predict(features_[tid]); return boosting_->Predict(features_[tid].data());
}; };
} }
} }
...@@ -67,12 +67,6 @@ public: ...@@ -67,12 +67,6 @@ public:
* \brief Destructor * \brief Destructor
*/ */
~Predictor() { ~Predictor() {
if (features_ != nullptr) {
for (int i = 0; i < num_threads_; ++i) {
delete[] features_[i];
}
delete[] features_;
}
} }
inline const PredictFunction& GetPredictFunction() { inline const PredictFunction& GetPredictFunction() {
...@@ -97,7 +91,7 @@ public: ...@@ -97,7 +91,7 @@ public:
if (result_file == NULL) { if (result_file == NULL) {
Log::Fatal("Prediction results file %s doesn't exist", data_filename); Log::Fatal("Prediction results file %s doesn't exist", data_filename);
} }
Parser* parser = Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx()); auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx()));
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename); Log::Fatal("Could not recognize the data format of data file %s", data_filename);
...@@ -133,14 +127,13 @@ public: ...@@ -133,14 +127,13 @@ public:
predict_data_reader.ReadAllAndProcessParallel(process_fun); predict_data_reader.ReadAllAndProcessParallel(process_fun);
fclose(result_file); fclose(result_file);
delete parser;
} }
private: private:
int PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) { int PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
// init feature value // init feature value
std::memset(features_[tid], 0, sizeof(double)*num_features_); std::memset(features_[tid].data(), 0, sizeof(double)*num_features_);
// put feature value // put feature value
for (const auto& p : features) { for (const auto& p : features) {
if (p.first < num_features_) { if (p.first < num_features_) {
...@@ -152,7 +145,7 @@ private: ...@@ -152,7 +145,7 @@ private:
/*! \brief Boosting model */ /*! \brief Boosting model */
const Boosting* boosting_; const Boosting* boosting_;
/*! \brief Buffer for feature values */ /*! \brief Buffer for feature values */
double** features_; std::vector<std::vector<double>> features_;
/*! \brief Number of features */ /*! \brief Number of features */
int num_features_; int num_features_;
/*! \brief Number of threads */ /*! \brief Number of threads */
......
#include <LightGBM/boosting.h> #include <LightGBM/boosting.h>
#include "gbdt.h" #include "gbdt.h"
#include "dart.h" #include "dart.hpp"
namespace LightGBM { namespace LightGBM {
...@@ -37,32 +37,32 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) { ...@@ -37,32 +37,32 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
return nullptr; return nullptr;
} }
} else { } else {
Boosting* ret = nullptr; std::unique_ptr<Boosting> ret;
auto type_in_file = GetBoostingTypeFromModelFile(filename); auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == type) { if (type_in_file == type) {
if (type == BoostingType::kGBDT) { if (type == BoostingType::kGBDT) {
ret = new GBDT(); ret.reset(new GBDT());
} else if (type == BoostingType::kDART) { } else if (type == BoostingType::kDART) {
ret = new DART(); ret.reset(new DART());
} }
LoadFileToBoosting(ret, filename); LoadFileToBoosting(ret.get(), filename);
} else { } else {
Log::Fatal("Boosting type in parameter is not the same as the type in the model file"); Log::Fatal("Boosting type in parameter is not the same as the type in the model file");
} }
return ret; return ret.release();
} }
} }
Boosting* Boosting::CreateBoosting(const char* filename) { Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename); auto type = GetBoostingTypeFromModelFile(filename);
Boosting* ret = nullptr; std::unique_ptr<Boosting> ret;
if (type == BoostingType::kGBDT) { if (type == BoostingType::kGBDT) {
ret = new GBDT(); ret.reset(new GBDT());
} else if (type == BoostingType::kDART) { } else if (type == BoostingType::kDART) {
ret = new DART(); ret.reset(new DART());
} }
LoadFileToBoosting(ret, filename); LoadFileToBoosting(ret.get(), filename);
return ret; return ret.release();
} }
} // namespace LightGBM } // namespace LightGBM
#include "gbdt.h"
#include "dart.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/feature.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <ctime>
#include <sstream>
#include <chrono>
#include <string>
#include <vector>
#include <utility>
namespace LightGBM {
// Default constructor; real setup happens in Init().
DART::DART(){
}
// Destructor; owned resources (models_, etc.) are presumably released by the
// GBDT base class — confirm against gbdt.h.
DART::~DART(){
}
/*!
* \brief Initialize DART: run the common GBDT setup, then pull the
*        DART-specific settings (drop rate / drop seed) out of the config.
*/
void DART::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
  const std::vector<const Metric*>& training_metrics) {
  // common boosting setup is delegated to the base class
  GBDT::Init(config, train_data, object_function, training_metrics);
  // DART needs the GBDT-flavored config for its extra parameters
  gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
  random_for_drop_ = Random(gbdt_config_->drop_seed);
  drop_rate_ = gbdt_config_->drop_rate;
  // no dropped trees yet, so new trees start at full weight
  shrinkage_rate_ = 1.0;
}
bool DART::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// boosting first
if (gradient == nullptr || hessian == nullptr) {
Boosting();
gradient = gradients_;
hessian = hessians_;
}
for (int curr_class = 0; curr_class < num_class_; ++curr_class){
// bagging logic
Bagging(iter_, curr_class);
// train a new tree
Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
return true;
}
// shrink new tree
new_tree->Shrinkage(shrinkage_rate_);
// update score
UpdateScore(new_tree, curr_class);
UpdateScoreOutOfBag(new_tree, curr_class);
// add model
models_.push_back(new_tree);
}
// normalize
Normalize();
bool is_met_early_stopping = false;
// print message for metric
if (is_eval) {
is_met_early_stopping = OutputMetric(iter_ + 1);
}
++iter_;
if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d",
iter_, iter_ - early_stopping_round_);
// pop last early_stopping_round_ models
for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
delete models_.back();
models_.pop_back();
}
}
return is_met_early_stopping;
}
/*! \brief Get training scores result */
const score_t* DART::GetTrainingScore(data_size_t* out_len) {
DroppingTrees();
*out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score();
}
/*!
* \brief Save the model to file. DART keeps rescaling its trees while
*        training, so the model is persisted exactly once: when training
*        has finished and nothing has been saved yet.
*/
void DART::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) {
  // skip intermediate saves and repeated final saves
  if (!is_finish || saved_model_size_ >= 0) {
    return;
  }
  GBDT::SaveModelToFile(num_used_model, is_finish, filename);
}
void DART::DroppingTrees() {
drop_index_.clear();
// select dropping tree indexes based on drop_rate
// if drop rate is too small, skip this step, drop one tree randomly
if (drop_rate_ > kEpsilon) {
for (size_t i = 0; i < static_cast<size_t>(iter_); ++i){
if (random_for_drop_.NextDouble() < drop_rate_) {
drop_index_.push_back(i);
}
}
}
// binomial-plus-one, at least one tree will be dropped
if (drop_index_.empty()){
drop_index_ = random_for_drop_.Sample(iter_, 1);
}
// drop trees
for (auto i: drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
models_[curr_tree]->Shrinkage(-1.0);
train_score_updater_->AddScore(models_[curr_tree], curr_class);
}
}
shrinkage_rate_ = 1.0 / (1.0 + drop_index_.size());
}
/*!
* \brief Re-weight the dropped trees and fold them back into the
*        validation and training scores.
*/
void DART::Normalize() {
  const double num_dropped = static_cast<double>(drop_index_.size());
  for (auto iter_idx : drop_index_) {
    for (int klass = 0; klass < num_class_; ++klass) {
      const auto tree_idx = iter_idx * num_class_ + klass;
      // scale the dropped tree (currently negated) and apply it to every
      // validation score updater
      models_[tree_idx]->Shrinkage(shrinkage_rate_);
      for (auto& score_updater : valid_score_updater_) {
        score_updater->AddScore(models_[tree_idx], klass);
      }
      // restore the tree into the training score with normalized weight
      models_[tree_idx]->Shrinkage(-num_dropped);
      train_score_updater_->AddScore(models_[tree_idx], klass);
    }
  }
}
} // namespace LightGBM
...@@ -19,11 +19,11 @@ public: ...@@ -19,11 +19,11 @@ public:
/*! /*!
* \brief Constructor * \brief Constructor
*/ */
DART(); DART(): GBDT() { }
/*! /*!
* \brief Destructor * \brief Destructor
*/ */
~DART(); ~DART() { }
/*! /*!
* \brief Initialization logic * \brief Initialization logic
* \param config Config for boosting * \param config Config for boosting
...@@ -32,24 +32,48 @@ public: ...@@ -32,24 +32,48 @@ public:
* \param training_metrics Training metrics * \param training_metrics Training metrics
* \param output_model_filename Filename of output model * \param output_model_filename Filename of output model
*/ */
void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function, void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) const std::vector<const Metric*>& training_metrics) override {
override; GBDT::Init(config, train_data, object_function, training_metrics);
drop_rate_ = gbdt_config_->drop_rate;
shrinkage_rate_ = 1.0;
random_for_drop_ = Random(gbdt_config_->drop_seed);
}
/*! /*!
* \brief one training iteration * \brief one training iteration
*/ */
bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override; bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
GBDT::TrainOneIter(gradient, hessian, false);
// normalize
Normalize();
if (is_eval) {
return EvalAndCheckEarlyStopping();
} else {
return false;
}
}
/*! /*!
* \brief Get current training score * \brief Get current training score
* \param out_len lenght of returned score * \param out_len length of returned score
* \return training score * \return training score
*/ */
const score_t* GetTrainingScore(data_size_t* out_len) override; const score_t* GetTrainingScore(data_size_t* out_len) override {
DroppingTrees();
*out_len = train_score_updater_->num_data() * num_class_;
return train_score_updater_->score();
}
/*! /*!
* \brief Serialize models by string * \brief save model to file
* \return String output of tranined model * \param num_used_model number of model that want to save, -1 means save all
* \param is_finish is training finished or not
* \param filename filename that want to save to
*/ */
void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override; void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override {
// only save model once when is_finish = true
if (is_finish && saved_model_size_ < 0) {
GBDT::SaveModelToFile(num_used_model, is_finish, filename);
}
}
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
...@@ -59,17 +83,54 @@ private: ...@@ -59,17 +83,54 @@ private:
/*! /*!
* \brief drop trees based on drop_rate * \brief drop trees based on drop_rate
*/ */
void DroppingTrees(); void DroppingTrees() {
drop_index_.clear();
// select dropping tree indexes based on drop_rate
// if drop rate is too small, skip this step, drop one tree randomly
if (drop_rate_ > kEpsilon) {
for (size_t i = 0; i < static_cast<size_t>(iter_); ++i) {
if (random_for_drop_.NextDouble() < drop_rate_) {
drop_index_.push_back(i);
}
}
}
// binomial-plus-one, at least one tree will be dropped
if (drop_index_.empty()) {
drop_index_ = random_for_drop_.Sample(iter_, 1);
}
// drop trees
for (auto i : drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
models_[curr_tree]->Shrinkage(-1.0);
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
}
shrinkage_rate_ = 1.0 / (1.0 + drop_index_.size());
}
/*! /*!
* \brief normalize dropped trees * \brief normalize dropped trees
*/ */
void Normalize(); void Normalize() {
double k = static_cast<double>(drop_index_.size());
for (auto i : drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
// update validation score
models_[curr_tree]->Shrinkage(shrinkage_rate_);
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(models_[curr_tree].get(), curr_class);
}
// update training score
models_[curr_tree]->Shrinkage(-k);
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
}
}
/*! \brief The indexes of dropping trees */ /*! \brief The indexes of dropping trees */
std::vector<size_t> drop_index_; std::vector<size_t> drop_index_;
/*! \brief Dropping rate */ /*! \brief Dropping rate */
double drop_rate_; double drop_rate_;
/*! \brief Shrinkage rate for one iteration */
double shrinkage_rate_;
/*! \brief Random generator, used to select dropping trees */ /*! \brief Random generator, used to select dropping trees */
Random random_for_drop_; Random random_for_drop_;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment