Commit 5d022898, authored by Guolin Ke and committed by xuehui
Browse files

Refine some interfaces to better expose the API. (#52)

parent 66804b93
...@@ -54,9 +54,6 @@ private: ...@@ -54,9 +54,6 @@ private:
/*! \brief Initializations before prediction */ /*! \brief Initializations before prediction */
void InitPredict(); void InitPredict();
/*! \brief Load model from local disk */
void LoadModel();
/*! \brief Main predicting logic */ /*! \brief Main predicting logic */
void Predict(); void Predict();
......
...@@ -56,7 +56,7 @@ public: ...@@ -56,7 +56,7 @@ public:
/*! \brief True if bin is trivial (contains only one bin) */ /*! \brief True if bin is trivial (contains only one bin) */
inline bool is_trival() const { return is_trival_; } inline bool is_trival() const { return is_trival_; }
/*! \brief Sparsity of this bin ( num_zero_bins / num_data ) */ /*! \brief Sparsity of this bin ( num_zero_bins / num_data ) */
inline double sparse_rate() const { return sparse_rate_; } inline float sparse_rate() const { return sparse_rate_; }
/*! /*!
* \brief Save binary data to file * \brief Save binary data to file
* \param file File want to write * \param file File want to write
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
* \param bin * \param bin
* \return Feature value of this bin * \return Feature value of this bin
*/ */
inline double BinToValue(unsigned int bin) const { inline float BinToValue(unsigned int bin) const {
return bin_upper_bound_[bin]; return bin_upper_bound_[bin];
} }
/*! /*!
...@@ -79,14 +79,14 @@ public: ...@@ -79,14 +79,14 @@ public:
* \param value * \param value
* \return bin for this feature value * \return bin for this feature value
*/ */
inline unsigned int ValueToBin(double value) const; inline unsigned int ValueToBin(float value) const;
/*! /*!
* \brief Construct feature value to bin mapper according to feature values * \brief Construct feature value to bin mapper according to feature values
* \param values (Sampled) values of this feature * \param values (Sampled) values of this feature
* \param max_bin The maximal number of bin * \param max_bin The maximal number of bin
*/ */
void FindBin(std::vector<double>* values, int max_bin); void FindBin(std::vector<float>* values, int max_bin);
/*! /*!
* \brief Use specific number of bin to calculate the size of this class * \brief Use specific number of bin to calculate the size of this class
...@@ -111,11 +111,11 @@ private: ...@@ -111,11 +111,11 @@ private:
/*! \brief Number of bins */ /*! \brief Number of bins */
int num_bin_; int num_bin_;
/*! \brief Store upper bound for each bin */ /*! \brief Store upper bound for each bin */
double* bin_upper_bound_; float* bin_upper_bound_;
/*! \brief True if this feature is trival */ /*! \brief True if this feature is trival */
bool is_trival_; bool is_trival_;
/*! \brief Sparse rate of this bins( num_bin0/num_data ) */ /*! \brief Sparse rate of this bins( num_bin0/num_data ) */
double sparse_rate_; float sparse_rate_;
}; };
/*! /*!
...@@ -271,7 +271,7 @@ public: ...@@ -271,7 +271,7 @@ public:
* \return The bin data object * \return The bin data object
*/ */
static Bin* CreateBin(data_size_t num_data, int num_bin, static Bin* CreateBin(data_size_t num_data, int num_bin,
double sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin); float sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin);
/*! /*!
* \brief Create object for bin data of one feature, used for dense feature * \brief Create object for bin data of one feature, used for dense feature
...@@ -293,7 +293,7 @@ public: ...@@ -293,7 +293,7 @@ public:
int num_bin, int default_bin); int num_bin, int default_bin);
}; };
inline unsigned int BinMapper::ValueToBin(double value) const { inline unsigned int BinMapper::ValueToBin(float value) const {
// binary search to find bin // binary search to find bin
int l = 0; int l = 0;
int r = num_bin_ - 1; int r = num_bin_ - 1;
......
...@@ -28,12 +28,12 @@ public: ...@@ -28,12 +28,12 @@ public:
* \param train_data Training data * \param train_data Training data
* \param object_function Training objective function * \param object_function Training objective function
* \param training_metrics Training metric * \param training_metrics Training metric
* \param output_model_filename Filename of output model
*/ */
virtual void Init(const Dataset* train_data, virtual void Init(
const BoostingConfig* config,
const Dataset* train_data,
const ObjectiveFunction* object_function, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics, const std::vector<const Metric*>& training_metrics) = 0;
const char* output_model_filename) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
...@@ -44,40 +44,52 @@ public: ...@@ -44,40 +44,52 @@ public:
const std::vector<const Metric*>& valid_metrics) = 0; const std::vector<const Metric*>& valid_metrics) = 0;
/*! \brief Training logic */ /*! \brief Training logic */
virtual void Train() = 0; virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;
/*! \brief Get eval result */
virtual std::vector<std::string> EvalCurrent(bool is_eval_train) const = 0 ;
/*! \brief Get prediction result */
virtual const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const = 0;
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double PredictRaw(const double * feature_values) const = 0; virtual float PredictRaw(const float* feature_values,
int num_used_model) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
virtual double Predict(const double * feature_values) const = 0; virtual float Predict(const float* feature_values,
int num_used_model) const = 0;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
virtual std::vector<int> PredictLeafIndex(const double * feature_values) const = 0; virtual std::vector<int> PredictLeafIndex(
const float* feature_values,
int num_used_model) const = 0;
/*! /*!
* \brief Serialize models by string * \brief save model to file
* \return String output of tranined model
*/ */
virtual std::string ModelsToString() const = 0; virtual void SaveModelToFile(bool is_finish, const char* filename) = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model * \param model_str The string of model
*/ */
virtual void ModelsFromString(const std::string& model_str, int num_used_model) = 0; virtual void ModelsFromString(const std::string& model_str) = 0;
/*! /*!
* \brief Get max feature index of this model * \brief Get max feature index of this model
...@@ -97,13 +109,27 @@ public: ...@@ -97,13 +109,27 @@ public:
*/ */
virtual int NumberOfSubModels() const = 0; virtual int NumberOfSubModels() const = 0;
/*!
* \brief Get Type name of this boosting object
*/
virtual const char* Name() const = 0;
/*! /*!
* \brief Create boosting object * \brief Create boosting object
* \param type Type of boosting * \param type Type of boosting
* \param config config for boosting
* \param filename name of model file, if existing will continue to train from this model
* \return The boosting object * \return The boosting object
*/ */
static Boosting* CreateBoosting(BoostingType type, static Boosting* CreateBoosting(BoostingType type, const char* filename);
const BoostingConfig* config);
/*!
* \brief Create boosting object from model file
* \param filename name of model file
* \return The boosting object
*/
static Boosting* CreateBoosting(const char* filename);
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -49,15 +49,15 @@ public: ...@@ -49,15 +49,15 @@ public:
const std::string& name, int* out); const std::string& name, int* out);
/*! /*!
* \brief Get double value by specific name of key * \brief Get float value by specific name of key
* \param params Store the key and value for params * \param params Store the key and value for params
* \param name Name of key * \param name Name of key
* \param out Value will assign to out if key exists * \param out Value will assign to out if key exists
* \return True if key exists * \return True if key exists
*/ */
inline bool GetDouble( inline bool GetFloat(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, double* out); const std::string& name, float* out);
/*! /*!
* \brief Get bool value by specific name of key * \brief Get bool value by specific name of key
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
/*! \brief Types of boosting */ /*! \brief Types of boosting */
enum BoostingType { enum BoostingType {
kGBDT kGBDT, kUnknow
}; };
...@@ -121,9 +121,9 @@ public: ...@@ -121,9 +121,9 @@ public:
struct ObjectiveConfig: public ConfigBase { struct ObjectiveConfig: public ConfigBase {
public: public:
virtual ~ObjectiveConfig() {} virtual ~ObjectiveConfig() {}
double sigmoid = 1; float sigmoid = 1.0f;
// for lambdarank // for lambdarank
std::vector<double> label_gain; std::vector<float> label_gain;
// for lambdarank // for lambdarank
int max_position = 20; int max_position = 20;
// for binary // for binary
...@@ -135,11 +135,8 @@ public: ...@@ -135,11 +135,8 @@ public:
struct MetricConfig: public ConfigBase { struct MetricConfig: public ConfigBase {
public: public:
virtual ~MetricConfig() {} virtual ~MetricConfig() {}
int early_stopping_round = 0; float sigmoid = 1.0f;
int output_freq = 1; std::vector<float> label_gain;
double sigmoid = 1;
bool is_provide_training_metric = false;
std::vector<double> label_gain;
std::vector<int> eval_at; std::vector<int> eval_at;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
...@@ -149,13 +146,13 @@ public: ...@@ -149,13 +146,13 @@ public:
struct TreeConfig: public ConfigBase { struct TreeConfig: public ConfigBase {
public: public:
int min_data_in_leaf = 100; int min_data_in_leaf = 100;
double min_sum_hessian_in_leaf = 10.0f; float min_sum_hessian_in_leaf = 10.0f;
// should > 1, only one leaf means not need to learning // should > 1, only one leaf means not need to learning
int num_leaves = 127; int num_leaves = 127;
int feature_fraction_seed = 2; int feature_fraction_seed = 2;
double feature_fraction = 1.0; float feature_fraction = 1.0f;
// max cache size(unit:MB) for historical histogram. < 0 means not limit // max cache size(unit:MB) for historical histogram. < 0 means not limit
double histogram_pool_size = -1; float histogram_pool_size = -1.0f;
// max depth of tree model. // max depth of tree model.
// Still grow tree by leaf-wise, but limit the max depth to avoid over-fitting // Still grow tree by leaf-wise, but limit the max depth to avoid over-fitting
// And the max leaves will be min(num_leaves, pow(2, max_depth - 1)) // And the max leaves will be min(num_leaves, pow(2, max_depth - 1))
...@@ -174,9 +171,11 @@ enum TreeLearnerType { ...@@ -174,9 +171,11 @@ enum TreeLearnerType {
struct BoostingConfig: public ConfigBase { struct BoostingConfig: public ConfigBase {
public: public:
virtual ~BoostingConfig() {} virtual ~BoostingConfig() {}
int output_freq = 1;
bool is_provide_training_metric = false;
int num_iterations = 10; int num_iterations = 10;
double learning_rate = 0.1; float learning_rate = 0.1f;
double bagging_fraction = 1.0; float bagging_fraction = 1.0f;
int bagging_seed = 3; int bagging_seed = 3;
int bagging_freq = 0; int bagging_freq = 0;
int early_stopping_round = 0; int early_stopping_round = 0;
...@@ -226,7 +225,7 @@ public: ...@@ -226,7 +225,7 @@ public:
delete boosting_config; delete boosting_config;
} }
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
void LoadFromString(const char* str);
private: private:
void GetBoostingType(const std::unordered_map<std::string, std::string>& params); void GetBoostingType(const std::unordered_map<std::string, std::string>& params);
...@@ -263,9 +262,9 @@ inline bool ConfigBase::GetInt( ...@@ -263,9 +262,9 @@ inline bool ConfigBase::GetInt(
return false; return false;
} }
inline bool ConfigBase::GetDouble( inline bool ConfigBase::GetFloat(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, double* out) { const std::string& name, float* out) {
if (params.count(name) > 0) { if (params.count(name) > 0) {
if (!Common::AtofAndCheck(params.at(name).c_str(), out)) { if (!Common::AtofAndCheck(params.at(name).c_str(), out)) {
Log::Fatal("Parameter %s should be float type, passed is [%s]", Log::Fatal("Parameter %s should be float type, passed is [%s]",
......
...@@ -108,9 +108,9 @@ public: ...@@ -108,9 +108,9 @@ public:
* \param idx Index of this record * \param idx Index of this record
* \param value Label value of this record * \param value Label value of this record
*/ */
// Stores `value` as the label of record `idx` in label_.
// idx is used as-is; no bounds check is performed in this function.
inline void SetLabelAt(data_size_t idx, double value) inline void SetLabelAt(data_size_t idx, float value)
{ {
label_[idx] = static_cast<float>(value); label_[idx] = value;
} }
/*! /*!
...@@ -118,9 +118,9 @@ public: ...@@ -118,9 +118,9 @@ public:
* \param idx Index of this record * \param idx Index of this record
* \param value Weight value of this record * \param value Weight value of this record
*/ */
// Stores `value` as the weight of record `idx` in weights_.
// idx is used as-is; no bounds check is performed in this function.
// NOTE(review): presumably weights_ must already be allocated by the caller — confirm.
inline void SetWeightAt(data_size_t idx, double value) inline void SetWeightAt(data_size_t idx, float value)
{ {
weights_[idx] = static_cast<float>(value); weights_[idx] = value;
} }
/*! /*!
...@@ -128,7 +128,7 @@ public: ...@@ -128,7 +128,7 @@ public:
* \param idx Index of this record * \param idx Index of this record
* \param value Query Id value of this record * \param value Query Id value of this record
*/ */
// Stores the query id of record `idx` in queries_, converting the
// floating-point `value` to data_size_t with static_cast (truncation toward zero).
// NOTE(review): with a `float` parameter, integers above 2^24 (16,777,216) are
// not all exactly representable, so large query ids may not round-trip through
// this setter — confirm the expected id range or pass an integer type.
inline void SetQueryAt(data_size_t idx, double value) inline void SetQueryAt(data_size_t idx, float value)
{ {
queries_[idx] = static_cast<data_size_t>(value); queries_[idx] = static_cast<data_size_t>(value);
} }
...@@ -221,7 +221,7 @@ public: ...@@ -221,7 +221,7 @@ public:
* \param out_label Label will store to this if exists * \param out_label Label will store to this if exists
*/ */
virtual void ParseOneLine(const char* str, virtual void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0; std::vector<std::pair<int, float>>* out_features, float* out_label) const = 0;
/*! /*!
* \brief Create a object of parser, will auto choose the format depend on file * \brief Create a object of parser, will auto choose the format depend on file
...@@ -234,7 +234,7 @@ public: ...@@ -234,7 +234,7 @@ public:
}; };
using PredictFunction = using PredictFunction =
std::function<double(const std::vector<std::pair<int, double>>&)>; std::function<float(const std::vector<std::pair<int, float>>&)>;
/*! \brief The main class of data set, /*! \brief The main class of data set,
* which is used for training or validation * which is used for training or validation
......
...@@ -71,7 +71,7 @@ public: ...@@ -71,7 +71,7 @@ public:
* \param idx Index of record * \param idx Index of record
* \param value feature value of record * \param value feature value of record
*/ */
// Converts one raw feature value to its bin index and appends it to the bin data.
// NOTE(review): `tid` is forwarded unchanged to bin_data_->Push and is not described
// by the doc comment above; it looks like a thread id — confirm.
inline void PushData(int tid, data_size_t line_idx, double value) { inline void PushData(int tid, data_size_t line_idx, float value) {
// Discretize the raw value with the feature's bin mapper.
unsigned int bin = bin_mapper_->ValueToBin(value); unsigned int bin = bin_mapper_->ValueToBin(value);
// Record the bin for row line_idx in the underlying bin storage.
bin_data_->Push(tid, line_idx, bin); bin_data_->Push(tid, line_idx, bin);
} }
...@@ -89,7 +89,7 @@ public: ...@@ -89,7 +89,7 @@ public:
* \param bin * \param bin
* \return Feature value of this bin * \return Feature value of this bin
*/ */
inline double BinToValue(unsigned int bin) inline float BinToValue(unsigned int bin)
const { return bin_mapper_->BinToValue(bin); } const { return bin_mapper_->BinToValue(bin); }
/*! /*!
......
...@@ -12,7 +12,7 @@ namespace LightGBM { ...@@ -12,7 +12,7 @@ namespace LightGBM {
/*! \brief Type of data size, it is better to use signed type*/ /*! \brief Type of data size, it is better to use signed type*/
typedef int32_t data_size_t; typedef int32_t data_size_t;
/*! \brief Type of score, and gradients */ /*! \brief Type of score, and gradients */
typedef double score_t; typedef float score_t;
const score_t kMinScore = -std::numeric_limits<score_t>::infinity(); const score_t kMinScore = -std::numeric_limits<score_t>::infinity();
......
...@@ -11,7 +11,7 @@ namespace LightGBM { ...@@ -11,7 +11,7 @@ namespace LightGBM {
/*! /*!
* \brief The interface of metric. * \brief The interface of metric.
* Metric is used to calculate and output metric result on training / validation data. * Metric is used to calculate metric result
*/ */
class Metric { class Metric {
public: public:
...@@ -27,12 +27,14 @@ public: ...@@ -27,12 +27,14 @@ public:
virtual void Init(const char* test_name, virtual void Init(const char* test_name,
const Metadata& metadata, data_size_t num_data) = 0; const Metadata& metadata, data_size_t num_data) = 0;
virtual const char* GetName() const = 0;
virtual bool is_bigger_better() const = 0;
/*! /*!
* \brief Calculating and printing metric result * \brief Calculating and printing metric result
* \param iter Current iteration
* \param score Current prediction score * \param score Current prediction score
*/ */
virtual score_t PrintAndGetLoss(int iter, const score_t* score) const = 0; virtual std::vector<float> Eval(const score_t* score) const = 0;
/*! /*!
* \brief Create object of metrics * \brief Create object of metrics
...@@ -41,8 +43,6 @@ public: ...@@ -41,8 +43,6 @@ public:
*/ */
static Metric* CreateMetric(const std::string& type, const MetricConfig& config); static Metric* CreateMetric(const std::string& type, const MetricConfig& config);
bool the_bigger_the_better = false;
int early_stopping_round_ = 0;
}; };
/*! /*!
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
* \brief Initial logic * \brief Initial logic
* \param label_gain Gain for labels, default is 2^i - 1 * \param label_gain Gain for labels, default is 2^i - 1
*/ */
static void Init(std::vector<double> label_gain); static void Init(std::vector<float> label_gain);
/*! /*!
* \brief Calculate the DCG score at position k * \brief Calculate the DCG score at position k
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
* \param num_data Number of data * \param num_data Number of data
* \return The DCG score * \return The DCG score
*/ */
static double CalDCGAtK(data_size_t k, const float* label, static float CalDCGAtK(data_size_t k, const float* label,
const score_t* score, data_size_t num_data); const score_t* score, data_size_t num_data);
/*! /*!
...@@ -77,7 +77,7 @@ public: ...@@ -77,7 +77,7 @@ public:
*/ */
static void CalDCG(const std::vector<data_size_t>& ks, static void CalDCG(const std::vector<data_size_t>& ks,
const float* label, const score_t* score, const float* label, const score_t* score,
data_size_t num_data, std::vector<double>* out); data_size_t num_data, std::vector<float>* out);
/*! /*!
* \brief Calculate the Max DCG score at position k * \brief Calculate the Max DCG score at position k
...@@ -86,7 +86,7 @@ public: ...@@ -86,7 +86,7 @@ public:
* \param num_data Number of data * \param num_data Number of data
* \return The max DCG score * \return The max DCG score
*/ */
static double CalMaxDCGAtK(data_size_t k, static float CalMaxDCGAtK(data_size_t k,
const float* label, data_size_t num_data); const float* label, data_size_t num_data);
/*! /*!
...@@ -97,22 +97,22 @@ public: ...@@ -97,22 +97,22 @@ public:
* \param out Output result * \param out Output result
*/ */
static void CalMaxDCG(const std::vector<data_size_t>& ks, static void CalMaxDCG(const std::vector<data_size_t>& ks,
const float* label, data_size_t num_data, std::vector<double>* out); const float* label, data_size_t num_data, std::vector<float>* out);
/*! /*!
* \brief Get discount score of position k * \brief Get discount score of position k
* \param k The position * \param k The position
* \return The discount of this position * \return The discount of this position
*/ */
inline static double GetDiscount(data_size_t k) { return discount_[k]; } inline static float GetDiscount(data_size_t k) { return discount_[k]; }
private: private:
/*! \brief True if inited, avoid init multi times */ /*! \brief True if inited, avoid init multi times */
static bool is_inited_; static bool is_inited_;
/*! \brief store gains for different label */ /*! \brief store gains for different label */
static std::vector<double> label_gain_; static std::vector<float> label_gain_;
/*! \brief store discount score for different position */ /*! \brief store discount score for different position */
static std::vector<double> discount_; static std::vector<float> discount_;
/*! \brief max position for eval */ /*! \brief max position for eval */
static const data_size_t kMaxPosition; static const data_size_t kMaxPosition;
}; };
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
* This function is used for prediction task, if has sigmoid param, the prediction value will be transform by sigmoid function. * This function is used for prediction task, if has sigmoid param, the prediction value will be transform by sigmoid function.
* \return Sigmoid param, if <=0.0 means don't use sigmoid transform on this objective. * \return Sigmoid param, if <=0.0 means don't use sigmoid transform on this objective.
*/ */
virtual double GetSigmoid() const = 0; virtual float GetSigmoid() const = 0;
/*! /*!
* \brief Create object of objective function * \brief Create object of objective function
......
...@@ -36,15 +36,15 @@ public: ...@@ -36,15 +36,15 @@ public:
* \param feature Index of feature; the converted index after removing useless features * \param feature Index of feature; the converted index after removing useless features
* \param threshold Threshold(bin) of split * \param threshold Threshold(bin) of split
* \param real_feature Index of feature, the original index on data * \param real_feature Index of feature, the original index on data
* \param threshold_double Threshold on feature value * \param threshold_float Threshold on feature value
* \param left_value Model Left child output * \param left_value Model Left child output
* \param right_value Model Right child output * \param right_value Model Right child output
* \param gain Split gain * \param gain Split gain
* \return The index of new leaf. * \return The index of new leaf.
*/ */
int Split(int leaf, int feature, unsigned int threshold, int real_feature, int Split(int leaf, int feature, unsigned int threshold, int real_feature,
double threshold_double, score_t left_value, float threshold_float, score_t left_value,
score_t right_value, double gain); score_t right_value, float gain);
/*! \brief Get the output of one leave */ /*! \brief Get the output of one leave */
inline score_t LeafOutput(int leaf) const { return leaf_value_[leaf]; } inline score_t LeafOutput(int leaf) const { return leaf_value_[leaf]; }
...@@ -74,8 +74,8 @@ public: ...@@ -74,8 +74,8 @@ public:
* \param feature_values Feature value of this record * \param feature_values Feature value of this record
* \return Prediction result * \return Prediction result
*/ */
inline score_t Predict(const double* feature_values) const; inline score_t Predict(const float* feature_values) const;
inline int PredictLeafIndex(const double* feature_values) const; inline int PredictLeafIndex(const float* feature_values) const;
/*! \brief Get Number of leaves*/ /*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; } inline int num_leaves() const { return num_leaves_; }
...@@ -91,7 +91,7 @@ public: ...@@ -91,7 +91,7 @@ public:
* shrinkage rate (a.k.a learning rate) is used to tune the training process * shrinkage rate (a.k.a learning rate) is used to tune the training process
* \param rate The factor of shrinkage * \param rate The factor of shrinkage
*/ */
inline void Shrinkage(double rate) { inline void Shrinkage(float rate) {
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] = static_cast<score_t>(leaf_value_[i] * rate); leaf_value_[i] = static_cast<score_t>(leaf_value_[i] * rate);
} }
...@@ -119,7 +119,7 @@ private: ...@@ -119,7 +119,7 @@ private:
* \param feature_values Feature value of this record * \param feature_values Feature value of this record
* \return Leaf index * \return Leaf index
*/ */
inline int GetLeaf(const double* feature_values) const; inline int GetLeaf(const float* feature_values) const;
/*! \brief Number of max leaves*/ /*! \brief Number of max leaves*/
int max_leaves_; int max_leaves_;
...@@ -137,9 +137,9 @@ private: ...@@ -137,9 +137,9 @@ private:
/*! \brief A non-leaf node's split threshold in bin */ /*! \brief A non-leaf node's split threshold in bin */
unsigned int* threshold_in_bin_; unsigned int* threshold_in_bin_;
/*! \brief A non-leaf node's split threshold in feature value */ /*! \brief A non-leaf node's split threshold in feature value */
double* threshold_; float* threshold_;
/*! \brief A non-leaf node's split gain */ /*! \brief A non-leaf node's split gain */
double* split_gain_; float* split_gain_;
// used for leaf node // used for leaf node
/*! \brief The parent of leaf */ /*! \brief The parent of leaf */
int* leaf_parent_; int* leaf_parent_;
...@@ -150,12 +150,12 @@ private: ...@@ -150,12 +150,12 @@ private:
}; };
// Predicts the output of this tree for one record given its raw feature values.
inline score_t Tree::Predict(const double* feature_values) const { inline score_t Tree::Predict(const float* feature_values) const {
// Walk the tree to find the leaf this record falls into.
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
// The prediction is the output value stored for that leaf.
return LeafOutput(leaf); return LeafOutput(leaf);
} }
// Returns the index of the leaf that one record falls into, rather than
// the leaf's output value.
inline int Tree::PredictLeafIndex(const double* feature_values) const { inline int Tree::PredictLeafIndex(const float* feature_values) const {
// Same traversal as Predict(); only the leaf index is returned.
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
return leaf; return leaf;
} }
...@@ -174,7 +174,7 @@ inline int Tree::GetLeaf(const std::vector<BinIterator*>& iterators, ...@@ -174,7 +174,7 @@ inline int Tree::GetLeaf(const std::vector<BinIterator*>& iterators,
return ~node; return ~node;
} }
inline int Tree::GetLeaf(const double* feature_values) const { inline int Tree::GetLeaf(const float* feature_values) const {
int node = 0; int node = 0;
while (node >= 0) { while (node >= 0) {
if (feature_values[split_feature_real_[node]] <= threshold_[node]) { if (feature_values[split_feature_real_[node]] <= threshold_[node]) {
......
...@@ -103,9 +103,9 @@ inline static const char* Atoi(const char* p, int* out) { ...@@ -103,9 +103,9 @@ inline static const char* Atoi(const char* p, int* out) {
} }
//ref to http://www.leapsecond.com/tools/fast_atof.c //ref to http://www.leapsecond.com/tools/fast_atof.c
inline static const char* Atof(const char* p, double* out) { inline static const char* Atof(const char* p, float* out) {
int frac; int frac;
double sign, value, scale; float sign, value, scale;
*out = 0; *out = 0;
// Skip leading white space, if any. // Skip leading white space, if any.
while (*p == ' ') { while (*p == ' ') {
...@@ -113,9 +113,9 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -113,9 +113,9 @@ inline static const char* Atof(const char* p, double* out) {
} }
// Get sign, if any. // Get sign, if any.
sign = 1.0; sign = 1.0f;
if (*p == '-') { if (*p == '-') {
sign = -1.0; sign = -1.0f;
++p; ++p;
} }
else if (*p == '+') { else if (*p == '+') {
...@@ -125,24 +125,24 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -125,24 +125,24 @@ inline static const char* Atof(const char* p, double* out) {
// is a number // is a number
if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') { if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') {
// Get digits before decimal point or exponent, if any. // Get digits before decimal point or exponent, if any.
for (value = 0.0; *p >= '0' && *p <= '9'; ++p) { for (value = 0.0f; *p >= '0' && *p <= '9'; ++p) {
value = value * 10.0 + (*p - '0'); value = value * 10.0f + (*p - '0');
} }
// Get digits after decimal point, if any. // Get digits after decimal point, if any.
if (*p == '.') { if (*p == '.') {
double pow10 = 10.0; float pow10 = 10.0f;
++p; ++p;
while (*p >= '0' && *p <= '9') { while (*p >= '0' && *p <= '9') {
value += (*p - '0') / pow10; value += (*p - '0') / pow10;
pow10 *= 10.0; pow10 *= 10.0f;
++p; ++p;
} }
} }
// Handle exponent, if any. // Handle exponent, if any.
frac = 0; frac = 0;
scale = 1.0; scale = 1.0f;
if ((*p == 'e') || (*p == 'E')) { if ((*p == 'e') || (*p == 'E')) {
unsigned int expon; unsigned int expon;
// Get sign of exponent, if any. // Get sign of exponent, if any.
...@@ -157,11 +157,9 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -157,11 +157,9 @@ inline static const char* Atof(const char* p, double* out) {
for (expon = 0; *p >= '0' && *p <= '9'; ++p) { for (expon = 0; *p >= '0' && *p <= '9'; ++p) {
expon = expon * 10 + (*p - '0'); expon = expon * 10 + (*p - '0');
} }
if (expon > 308) expon = 308; if (expon > 38) expon = 38;
// Calculate scaling factor.
while (expon >= 50) { scale *= 1E50; expon -= 50; }
while (expon >= 8) { scale *= 1E8; expon -= 8; } while (expon >= 8) { scale *= 1E8; expon -= 8; }
while (expon > 0) { scale *= 10.0; expon -= 1; } while (expon > 0) { scale *= 10.0f; expon -= 1; }
} }
// Return signed and scaled floating point result. // Return signed and scaled floating point result.
*out = sign * (frac ? (value / scale) : (value * scale)); *out = sign * (frac ? (value / scale) : (value * scale));
...@@ -177,9 +175,9 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -177,9 +175,9 @@ inline static const char* Atof(const char* p, double* out) {
std::string tmp_str(p, cnt); std::string tmp_str(p, cnt);
std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), ::tolower); std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), ::tolower);
if (tmp_str == std::string("na") || tmp_str == std::string("nan")) { if (tmp_str == std::string("na") || tmp_str == std::string("nan")) {
*out = 0; *out = 0.0f;
} else if( tmp_str == std::string("inf") || tmp_str == std::string("infinity")) { } else if( tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
*out = sign * 1e308; *out = sign * static_cast<float>(1e38);
} }
else { else {
Log::Fatal("Unknow token %s in data file", tmp_str.c_str()); Log::Fatal("Unknow token %s in data file", tmp_str.c_str());
...@@ -203,7 +201,7 @@ inline bool AtoiAndCheck(const char* p, int* out) { ...@@ -203,7 +201,7 @@ inline bool AtoiAndCheck(const char* p, int* out) {
return true; return true;
} }
inline bool AtofAndCheck(const char* p, double* out) { inline bool AtofAndCheck(const char* p, float* out) {
const char* after = Atof(p, out); const char* after = Atof(p, out);
if (*after != '\0') { if (*after != '\0') {
return false; return false;
...@@ -230,56 +228,57 @@ inline static std::string ArrayToString(const T* arr, int n, char delimiter) { ...@@ -230,56 +228,57 @@ inline static std::string ArrayToString(const T* arr, int n, char delimiter) {
if (n <= 0) { if (n <= 0) {
return std::string(""); return std::string("");
} }
std::stringstream ss; std::stringstream str_buf;
ss << arr[0]; str_buf << arr[0];
for (int i = 1; i < n; ++i) { for (int i = 1; i < n; ++i) {
ss << delimiter; str_buf << delimiter;
ss << arr[i]; str_buf << arr[i];
} }
return ss.str(); return str_buf.str();
} }
inline static void StringToIntArray(const std::string& str, char delimiter, size_t n, int* out) { template<typename T>
std::vector<std::string> strs = Split(str.c_str(), delimiter); inline static std::string ArrayToString(std::vector<T> arr, char delimiter) {
if (strs.size() != n) { if (arr.size() <= 0) {
Log::Fatal("StringToIntArray error, size doesn't matched."); return std::string("");
} }
for (size_t i = 0; i < strs.size(); ++i) { std::stringstream str_buf;
strs[i] = Trim(strs[i]); str_buf << arr[0];
Atoi(strs[i].c_str(), &out[i]); for (size_t i = 1; i < arr.size(); ++i) {
str_buf << delimiter;
str_buf << arr[i];
} }
return str_buf.str();
} }
inline static void StringToDoubleArray(const std::string& str, char delimiter, size_t n, double* out) { inline static void StringToIntArray(const std::string& str, char delimiter, size_t n, int* out) {
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) { if (strs.size() != n) {
Log::Fatal("StringToDoubleArray error, size doesn't matched."); Log::Fatal("StringToIntArray error, size doesn't matched.");
} }
for (size_t i = 0; i < strs.size(); ++i) { for (size_t i = 0; i < strs.size(); ++i) {
strs[i] = Trim(strs[i]); strs[i] = Trim(strs[i]);
Atof(strs[i].c_str(), &out[i]); Atoi(strs[i].c_str(), &out[i]);
} }
} }
inline static void StringToDoubleArray(const std::string& str, char delimiter, size_t n, float* out) { inline static void StringToFloatArray(const std::string& str, char delimiter, size_t n, float* out) {
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) { if (strs.size() != n) {
Log::Fatal("StringToDoubleArray error, size doesn't matched."); Log::Fatal("StringToFloatArray error, size doesn't matched.");
} }
double tmp;
for (size_t i = 0; i < strs.size(); ++i) { for (size_t i = 0; i < strs.size(); ++i) {
strs[i] = Trim(strs[i]); strs[i] = Trim(strs[i]);
Atof(strs[i].c_str(), &tmp); Atof(strs[i].c_str(), &out[i]);
out[i] = static_cast<float>(tmp);
} }
} }
inline static std::vector<double> StringToDoubleArray(const std::string& str, char delimiter) { inline static std::vector<float> StringToFloatArray(const std::string& str, char delimiter) {
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
std::vector<double> ret; std::vector<float> ret;
for (size_t i = 0; i < strs.size(); ++i) { for (size_t i = 0; i < strs.size(); ++i) {
strs[i] = Trim(strs[i]); strs[i] = Trim(strs[i]);
double val = 0.0; float val = 0.0f;
Atof(strs[i].c_str(), &val); Atof(strs[i].c_str(), &val);
ret.push_back(val); ret.push_back(val);
} }
......
...@@ -56,7 +56,7 @@ public: ...@@ -56,7 +56,7 @@ public:
} }
fclose(file); fclose(file);
first_line_ = str_buf.str(); first_line_ = str_buf.str();
Log::Info("skip header:\"%s\" in file %s", first_line_.c_str(), filename_); Log::Debug("skip header:\"%s\" in file %s", first_line_.c_str(), filename_);
} }
} }
/*! /*!
......
...@@ -121,17 +121,14 @@ void Application::LoadData() { ...@@ -121,17 +121,14 @@ void Application::LoadData() {
// predition is needed if using input initial model(continued train) // predition is needed if using input initial model(continued train)
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
Predictor* predictor = nullptr; Predictor* predictor = nullptr;
// load init model // need to continue train
if (config_.io_config.input_model.size() > 0) {
LoadModel();
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index); predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1);
predict_fun = predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) { [&predictor](const std::vector<std::pair<int, float>>& features) {
return predictor->PredictRawOneLine(features); return predictor->PredictRawOneLine(features);
}; };
} }
}
// sync up random seed for data partition // sync up random seed for data partition
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
config_.io_config.data_random_seed = config_.io_config.data_random_seed =
...@@ -156,7 +153,7 @@ void Application::LoadData() { ...@@ -156,7 +153,7 @@ void Application::LoadData() {
train_data_->SaveBinaryFile(); train_data_->SaveBinaryFile();
} }
// create training metric // create training metric
if (config_.metric_config.is_provide_training_metric) { if (config_.boosting_config->is_provide_training_metric) {
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
Metric* metric = Metric* metric =
Metric::CreateMetric(metric_type, config_.metric_config); Metric::CreateMetric(metric_type, config_.metric_config);
...@@ -213,12 +210,13 @@ void Application::InitTrain() { ...@@ -213,12 +210,13 @@ void Application::InitTrain() {
gbdt_config->tree_config.feature_fraction_seed = gbdt_config->tree_config.feature_fraction_seed =
GlobalSyncUpByMin<int>(gbdt_config->tree_config.feature_fraction_seed); GlobalSyncUpByMin<int>(gbdt_config->tree_config.feature_fraction_seed);
gbdt_config->tree_config.feature_fraction = gbdt_config->tree_config.feature_fraction =
GlobalSyncUpByMin<double>(gbdt_config->tree_config.feature_fraction); GlobalSyncUpByMin<float>(gbdt_config->tree_config.feature_fraction);
} }
} }
// create boosting // create boosting
boosting_ = boosting_ =
Boosting::CreateBoosting(config_.boosting_type, config_.boosting_config); Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str());
// create objective function // create objective function
objective_fun_ = objective_fun_ =
ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
...@@ -228,9 +226,8 @@ void Application::InitTrain() { ...@@ -228,9 +226,8 @@ void Application::InitTrain() {
// initialize the objective function // initialize the objective function
objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
// initialize the boosting // initialize the boosting
boosting_->Init(train_data_, objective_fun_, boosting_->Init(config_.boosting_config, train_data_, objective_fun_,
ConstPtrInVectorWarpper<Metric>(train_metric_), ConstPtrInVectorWarpper<Metric>(train_metric_));
config_.io_config.output_model.c_str());
// add validation data into boosting // add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) { for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i], boosting_->AddDataset(valid_datas_[i],
...@@ -240,15 +237,30 @@ void Application::InitTrain() { ...@@ -240,15 +237,30 @@ void Application::InitTrain() {
} }
void Application::Train() { void Application::Train() {
Log::Info("Start train"); Log::Info("Start train ...");
boosting_->Train(); int total_iter = config_.boosting_config->num_iterations;
Log::Info("Finish train"); bool is_finished = false;
bool need_eval = true;
auto start_time = std::chrono::high_resolution_clock::now();
for (int iter = 0; iter < total_iter && !is_finished; ++iter) {
is_finished = boosting_->TrainOneIter(nullptr, nullptr, need_eval);
auto end_time = std::chrono::high_resolution_clock::now();
// output used time per iteration
Log::Info("%f seconds elapsed, finished %d iteration", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
}
is_finished = true;
// save model to file
boosting_->SaveModelToFile(is_finished, config_.io_config.output_model.c_str());
Log::Info("Finished train");
} }
void Application::Predict() { void Application::Predict() {
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index); Predictor predictor(boosting_, config_.io_config.is_sigmoid,
config_.predict_leaf_index, config_.io_config.num_model_predict);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finish predict."); Log::Info("Finish predict.");
...@@ -256,21 +268,10 @@ void Application::Predict() { ...@@ -256,21 +268,10 @@ void Application::Predict() {
void Application::InitPredict() { void Application::InitPredict() {
boosting_ = boosting_ =
Boosting::CreateBoosting(config_.boosting_type, config_.boosting_config); Boosting::CreateBoosting(config_.io_config.input_model.c_str());
LoadModel();
Log::Info("Finish predict initilization."); Log::Info("Finish predict initilization.");
} }
void Application::LoadModel() {
TextReader<size_t> model_reader(config_.io_config.input_model.c_str(), false);
model_reader.ReadAllLines();
std::stringstream ss;
for (auto& line : model_reader.Lines()) {
ss << line << '\n';
}
boosting_->ModelsFromString(ss.str(), config_.io_config.num_model_predict);
}
template<typename T> template<typename T>
T Application::GlobalSyncUpByMin(T& local) { T Application::GlobalSyncUpByMin(T& local) {
T global = local; T global = local;
......
...@@ -28,8 +28,9 @@ public: ...@@ -28,8 +28,9 @@ public:
* \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification) * \param is_sigmoid True if need to predict result with sigmoid transform(if needed, like binary classification)
* \param predict_leaf_index True if output leaf index instead of prediction score * \param predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_simgoid, bool predict_leaf_index) Predictor(const Boosting* boosting, bool is_simgoid, bool is_predict_leaf_index, int num_used_model)
: is_simgoid_(is_simgoid), predict_leaf_index(predict_leaf_index) { : is_simgoid_(is_simgoid), is_predict_leaf_index_(is_predict_leaf_index),
num_used_model_(num_used_model) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
#pragma omp parallel #pragma omp parallel
...@@ -37,9 +38,9 @@ public: ...@@ -37,9 +38,9 @@ public:
{ {
num_threads_ = omp_get_num_threads(); num_threads_ = omp_get_num_threads();
} }
features_ = new double*[num_threads_]; features_ = new float*[num_threads_];
for (int i = 0; i < num_threads_; ++i) { for (int i = 0; i < num_threads_; ++i) {
features_[i] = new double[num_features_]; features_[i] = new float[num_features_];
} }
} }
/*! /*!
...@@ -59,10 +60,10 @@ public: ...@@ -59,10 +60,10 @@ public:
* \param features Feature for this record * \param features Feature for this record
* \return Prediction result * \return Prediction result
*/ */
double PredictRawOneLine(const std::vector<std::pair<int, double>>& features) { float PredictRawOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid]); return boosting_->PredictRaw(features_[tid], num_used_model_);
} }
/*! /*!
...@@ -70,10 +71,10 @@ public: ...@@ -70,10 +71,10 @@ public:
* \param features Feature for this record * \param features Feature for this record
* \return Predictied leaf index * \return Predictied leaf index
*/ */
std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<int> PredictLeafIndexOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result for leaf index // get result for leaf index
return boosting_->PredictLeafIndex(features_[tid]); return boosting_->PredictLeafIndex(features_[tid], num_used_model_);
} }
/*! /*!
...@@ -81,10 +82,10 @@ public: ...@@ -81,10 +82,10 @@ public:
* \param features Feature of this record * \param features Feature of this record
* \return Prediction result * \return Prediction result
*/ */
double PredictOneLine(const std::vector<std::pair<int, double>>& features) { float PredictOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->Predict(features_[tid]); return boosting_->Predict(features_[tid], num_used_model_);
} }
/*! /*!
* \brief predicting on data, then saving result to disk * \brief predicting on data, then saving result to disk
...@@ -111,17 +112,16 @@ public: ...@@ -111,17 +112,16 @@ public:
} }
// function for parse data // function for parse data
std::function<void(const char*, std::vector<std::pair<int, double>>*)> parser_fun; std::function<void(const char*, std::vector<std::pair<int, float>>*)> parser_fun;
double tmp_label; float tmp_label;
parser_fun = [this, &parser, &tmp_label] parser_fun = [this, &parser, &tmp_label]
(const char* buffer, std::vector<std::pair<int, double>>* feature) { (const char* buffer, std::vector<std::pair<int, float>>* feature) {
parser->ParseOneLine(buffer, feature, &tmp_label); parser->ParseOneLine(buffer, feature, &tmp_label);
}; };
std::function<std::string(const std::vector<std::pair<int, double>>&)> predict_fun; std::function<std::string(const std::vector<std::pair<int, float>>&)> predict_fun;
if (predict_leaf_index) { if (is_predict_leaf_index_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, float>>& features){
std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features); std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features);
std::stringstream result_ss; std::stringstream result_ss;
for (size_t i = 0; i < predicted_leaf_index.size(); ++i){ for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
...@@ -135,12 +135,12 @@ public: ...@@ -135,12 +135,12 @@ public:
} }
else { else {
if (is_simgoid_) { if (is_simgoid_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, float>>& features){
return std::to_string(PredictOneLine(features)); return std::to_string(PredictOneLine(features));
}; };
} }
else { else {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, float>>& features){
return std::to_string(PredictRawOneLine(features)); return std::to_string(PredictRawOneLine(features));
}; };
} }
...@@ -148,10 +148,10 @@ public: ...@@ -148,10 +148,10 @@ public:
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &parser_fun, &predict_fun, &result_file] [this, &parser_fun, &predict_fun, &result_file]
(data_size_t, const std::vector<std::string>& lines) { (data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
std::vector<std::string> pred_result(lines.size(), ""); std::vector<std::string> pred_result(lines.size(), "");
#pragma omp parallel for schedule(static) private(oneline_features) #pragma omp parallel for schedule(static) private(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); i++) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
oneline_features.clear(); oneline_features.clear();
// parser // parser
parser_fun(lines[i].c_str(), &oneline_features); parser_fun(lines[i].c_str(), &oneline_features);
...@@ -171,10 +171,10 @@ public: ...@@ -171,10 +171,10 @@ public:
} }
private: private:
int PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) { int PutFeatureValuesToBuffer(const std::vector<std::pair<int, float>>& features) {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
// init feature value // init feature value
std::memset(features_[tid], 0, sizeof(double)*num_features_); std::memset(features_[tid], 0, sizeof(float)*num_features_);
// put feature value // put feature value
for (const auto& p : features) { for (const auto& p : features) {
if (p.first < num_features_) { if (p.first < num_features_) {
...@@ -186,7 +186,7 @@ private: ...@@ -186,7 +186,7 @@ private:
/*! \brief Boosting model */ /*! \brief Boosting model */
const Boosting* boosting_; const Boosting* boosting_;
/*! \brief Buffer for feature values */ /*! \brief Buffer for feature values */
double** features_; float** features_;
/*! \brief Number of features */ /*! \brief Number of features */
int num_features_; int num_features_;
/*! \brief True if need to predict result with sigmoid transform */ /*! \brief True if need to predict result with sigmoid transform */
...@@ -194,7 +194,9 @@ private: ...@@ -194,7 +194,9 @@ private:
/*! \brief Number of threads */ /*! \brief Number of threads */
int num_threads_; int num_threads_;
/*! \brief True if output leaf index instead of prediction score */ /*! \brief True if output leaf index instead of prediction score */
bool predict_leaf_index; bool is_predict_leaf_index_;
/*! \brief Number of used model */
int num_used_model_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -3,13 +3,57 @@ ...@@ -3,13 +3,57 @@
namespace LightGBM { namespace LightGBM {
Boosting* Boosting::CreateBoosting(BoostingType type, BoostingType GetBoostingTypeFromModelFile(const char* filename) {
const BoostingConfig* config) { TextReader<size_t> model_reader(filename, true);
std::string type = model_reader.first_line();
if (type == std::string("gbdt")) {
return BoostingType::kGBDT;
}
return BoostingType::kUnknow;
}
void LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
}
boosting->ModelsFromString(str_buf.str());
}
}
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
if (filename[0] == '\0') {
if (type == BoostingType::kGBDT) { if (type == BoostingType::kGBDT) {
return new GBDT(config); return new GBDT();
} else { } else {
return nullptr; return nullptr;
} }
} else {
Boosting* ret = nullptr;
auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == type) {
if (type == BoostingType::kGBDT) {
ret = new GBDT();
}
LoadFileToBoosting(ret, filename);
} else {
Log::Fatal("Boosting type in parameter is not same with the type in model file");
}
return ret;
}
}
Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename);
Boosting* ret = nullptr;
if (type == BoostingType::kGBDT) {
ret = new GBDT();
}
LoadFileToBoosting(ret, filename);
return ret;
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -16,13 +16,10 @@ ...@@ -16,13 +16,10 @@
namespace LightGBM { namespace LightGBM {
GBDT::GBDT(const BoostingConfig* config) GBDT::GBDT()
: tree_learner_(nullptr), train_score_updater_(nullptr), : tree_learner_(nullptr), train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr), gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) { out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) {
max_feature_idx_ = 0;
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
early_stopping_round_ = gbdt_config_->early_stopping_round;
} }
GBDT::~GBDT() { GBDT::~GBDT() {
...@@ -40,8 +37,12 @@ GBDT::~GBDT() { ...@@ -40,8 +37,12 @@ GBDT::~GBDT() {
} }
} }
void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics, const char* output_model_filename) { const std::vector<const Metric*>& training_metrics) {
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
iter_ = 0;
max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data; train_data_ = train_data;
// create tree learner // create tree learner
tree_learner_ = tree_learner_ =
...@@ -57,8 +58,10 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct ...@@ -57,8 +58,10 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct
train_score_updater_ = new ScoreUpdater(train_data_); train_score_updater_ = new ScoreUpdater(train_data_);
num_data_ = train_data_->num_data(); num_data_ = train_data_->num_data();
// create buffer for gradients and hessians // create buffer for gradients and hessians
if (object_function_ != nullptr) {
gradients_ = new score_t[num_data_]; gradients_ = new score_t[num_data_];
hessians_ = new score_t[num_data_]; hessians_ = new score_t[num_data_];
}
// get max feature index // get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1; max_feature_idx_ = train_data_->num_total_features() - 1;
...@@ -77,18 +80,8 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct ...@@ -77,18 +80,8 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct
// initialize random generator // initialize random generator
random_ = Random(gbdt_config_->bagging_seed); random_ = Random(gbdt_config_->bagging_seed);
// open model output file
#ifdef _MSC_VER
fopen_s(&output_model_file, output_model_filename, "w");
#else
output_model_file = fopen(output_model_filename, "w");
#endif
// output models
fprintf(output_model_file, "%s", this->ModelsToString().c_str());
} }
void GBDT::AddDataset(const Dataset* valid_data, void GBDT::AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) { const std::vector<const Metric*>& valid_metrics) {
// for a validation dataset, we need its score and metric // for a validation dataset, we need its score and metric
...@@ -165,72 +158,46 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree) { ...@@ -165,72 +158,46 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree) {
} }
} }
void GBDT::Train() { bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// training start time
auto start_time = std::chrono::high_resolution_clock::now();
for (int iter = 0; iter < gbdt_config_->num_iterations; ++iter) {
// boosting first // boosting first
if (gradient == nullptr || hessian == nullptr) {
Boosting(); Boosting();
gradient = gradients_;
hessian = hessians_;
}
// bagging logic // bagging logic
Bagging(iter); Bagging(iter_);
// train a new tree // train a new tree
Tree * new_tree = TrainOneTree(); Tree * new_tree = tree_learner_->Train(gradient, hessian);
// if cannot learn a new tree, then stop // if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) { if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements."); Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
break; return true;
} }
// shrinkage by learning rate // shrinkage by learning rate
new_tree->Shrinkage(gbdt_config_->learning_rate); new_tree->Shrinkage(gbdt_config_->learning_rate);
// update score // update score
UpdateScore(new_tree); UpdateScore(new_tree);
UpdateScoreOutOfBag(new_tree); UpdateScoreOutOfBag(new_tree);
bool is_met_early_stopping = false;
// print message for metric // print message for metric
bool is_early_stopping = OutputMetric(iter + 1); if (is_eval) {
is_met_early_stopping = OutputMetric(iter_ + 1);
}
// add model // add model
models_.push_back(new_tree); models_.push_back(new_tree);
// save model to file per iteration ++iter_;
if (early_stopping_round_ > 0){ if (is_met_early_stopping) {
// if use early stopping, save previous model at (iter - early_stopping_round_) iteration Log::Info("Early stopping at iteration %d, the best iteration round is %d",
if (iter >= early_stopping_round_){ iter_, iter_ - early_stopping_round_);
fprintf(output_model_file, "Tree=%d\n", iter - early_stopping_round_); // pop last early_stopping_round_ models
Tree * printing_tree = models_.at(iter - early_stopping_round_); for (int i = 0; i < early_stopping_round_; ++i) {
fprintf(output_model_file, "%s\n", printing_tree->ToString().c_str()); delete models_.back();
fflush(output_model_file); models_.pop_back();
}
}
else{
fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", new_tree->ToString().c_str());
fflush(output_model_file);
}
auto end_time = std::chrono::high_resolution_clock::now();
// output used time per iteration
Log::Info("%f seconds elapsed, finished %d iteration", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
if (is_early_stopping) {
// close file with an early-stopping message
Log::Info("Early stopping at iteration %d, the best iteration round is %d", iter + 1, iter + 1 - early_stopping_round_);
FeatureImportance(iter - early_stopping_round_ + 1);
fclose(output_model_file);
return;
}
}
// close file
int remaining_models = gbdt_config_->num_iterations - early_stopping_round_;
if (early_stopping_round_ > 0 && remaining_models > 0) {
for (int iter = remaining_models; iter < static_cast<int>(models_.size()); ++iter){
fprintf(output_model_file, "Tree=%d\n", iter);
fprintf(output_model_file, "%s\n", models_.at(iter)->ToString().c_str());
} }
fflush(output_model_file);
} }
FeatureImportance(static_cast<int>(models_.size())); return is_met_early_stopping;
fclose(output_model_file);
}
Tree* GBDT::TrainOneTree() {
return tree_learner_->Train(gradients_, hessians_);
} }
void GBDT::UpdateScore(const Tree* tree) { void GBDT::UpdateScore(const Tree* tree) {
...@@ -245,57 +212,129 @@ void GBDT::UpdateScore(const Tree* tree) { ...@@ -245,57 +212,129 @@ void GBDT::UpdateScore(const Tree* tree) {
bool GBDT::OutputMetric(int iter) { bool GBDT::OutputMetric(int iter) {
bool ret = false; bool ret = false;
// print training metric // print training metric
if ((iter % gbdt_config_->output_freq) == 0) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
sub_metric->PrintAndGetLoss(iter, train_score_updater_->score()); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score());
Log::Info("Iteration:%d, %s : %s", iter, name, Common::ArrayToString<float>(scores, ' ').c_str());
}
} }
// print validation metric // print validation metric
if ((iter % gbdt_config_->output_freq) == 0 || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) { for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
score_t test_score_ = valid_metrics_[i][j]->PrintAndGetLoss(iter, valid_score_updater_[i]->score()); auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
if (!ret && early_stopping_round_ > 0){ if ((iter % gbdt_config_->output_freq) == 0) {
bool the_bigger_the_better_ = valid_metrics_[i][j]->the_bigger_the_better; auto name = valid_metrics_[i][j]->GetName();
Log::Info("Iteration:%d, %s : %s", iter, name, Common::ArrayToString<float>(test_scores, ' ').c_str());
}
if (!ret && early_stopping_round_ > 0) {
bool the_bigger_the_better = valid_metrics_[i][j]->is_bigger_better();
if (best_score_[i][j] < 0 if (best_score_[i][j] < 0
|| (!the_bigger_the_better_ && test_score_ < best_score_[i][j]) || (!the_bigger_the_better && test_scores.back() < best_score_[i][j])
|| ( the_bigger_the_better_ && test_score_ > best_score_[i][j])){ || (the_bigger_the_better && test_scores.back() > best_score_[i][j])) {
best_score_[i][j] = test_score_; best_score_[i][j] = test_scores.back();
best_iter_[i][j] = iter; best_iter_[i][j] = iter;
} } else {
else {
if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true; if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true;
} }
} }
} }
} }
}
return ret;
}
/*! \brief Get eval result */
std::vector<std::string> GBDT::EvalCurrent(bool is_eval_train) const {
std::vector<std::string> ret;
if (is_eval_train) {
for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score());
std::stringstream str_buf;
str_buf << name << " : " << Common::ArrayToString<float>(scores, ' ');
ret.emplace_back(str_buf.str());
}
}
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto name = valid_metrics_[i][j]->GetName();
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score());
std::stringstream str_buf;
str_buf << name << " : " << Common::ArrayToString<float>(test_scores, ' ');
ret.emplace_back(str_buf.str());
}
}
return ret;
}
/*! \brief Get prediction result */
const std::vector<const score_t*> GBDT::PredictCurrent(bool is_predict_train) const {
std::vector<const score_t*> ret;
if (is_predict_train) {
ret.push_back(train_score_updater_->score());
}
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
ret.push_back(valid_score_updater_[i]->score());
}
return ret; return ret;
} }
void GBDT::Boosting() { void GBDT::Boosting() {
if (object_function_ == nullptr) {
Log::Fatal("No object function provided");
}
// objective function will calculate gradients and hessians // objective function will calculate gradients and hessians
object_function_-> object_function_->
GetGradients(train_score_updater_->score(), gradients_, hessians_); GetGradients(train_score_updater_->score(), gradients_, hessians_);
} }
void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
std::string GBDT::ModelsToString() const { // first time to this function, open file
// serialize this object to string if (saved_model_size_ == -1) {
std::stringstream str_buf; model_output_file_.open(filename);
// output model type
model_output_file_ << "gbdt" << std::endl;
// output label index // output label index
str_buf << "label_index=" << label_idx_ << std::endl; model_output_file_ << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx // output max_feature_idx
str_buf << "max_feature_idx=" << max_feature_idx_ << std::endl; model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output sigmoid parameter // output sigmoid parameter
str_buf << "sigmoid=" << object_function_->GetSigmoid() << std::endl; model_output_file_ << "sigmoid=" << object_function_->GetSigmoid() << std::endl;
str_buf << std::endl; model_output_file_ << std::endl;
saved_model_size_ = 0;
}
// already saved
if (!model_output_file_.is_open()) {
return;
}
int rest = static_cast<int>(models_.size()) - early_stopping_round_;
// output tree models // output tree models
for (size_t i = 0; i < models_.size(); ++i) { for (int i = saved_model_size_; i < rest; ++i) {
str_buf << "Tree=" << i << std::endl; model_output_file_ << "Tree=" << i << std::endl;
str_buf << models_[i]->ToString() << std::endl; model_output_file_ << models_[i]->ToString() << std::endl;
}
if (rest > 0) {
saved_model_size_ = rest;
}
model_output_file_.flush();
// training finished, can close file
if (is_finish) {
for (int i = saved_model_size_; i < static_cast<int>(models_.size()); ++i) {
model_output_file_ << "Tree=" << i << std::endl;
model_output_file_ << models_[i]->ToString() << std::endl;
}
model_output_file_ << std::endl << FeatureImportance() << std::endl;
model_output_file_.close();
} }
return str_buf.str();
} }
void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) { void GBDT::ModelsFromString(const std::string& model_str) {
// use serialized string to restore this object // use serialized string to restore this object
models_.clear(); models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
...@@ -363,20 +402,16 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) { ...@@ -363,20 +402,16 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
int end = static_cast<int>(i); int end = static_cast<int>(i);
std::string tree_str = Common::Join(lines, start, end, '\n'); std::string tree_str = Common::Join(lines, start, end, '\n');
models_.push_back(new Tree(tree_str)); models_.push_back(new Tree(tree_str));
if (num_used_model > 0 && models_.size() >= static_cast<size_t>(num_used_model)) {
break;
}
} else { } else {
++i; ++i;
} }
} }
Log::Info("%d models has been loaded\n", models_.size()); Log::Info("%d models has been loaded\n", models_.size());
} }
void GBDT::FeatureImportance(const int last_iter) { std::string GBDT::FeatureImportance() const {
std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0); std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
for (int iter = 0; iter < last_iter; ++iter) { for (size_t iter = 0; iter < models_.size(); ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) { for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
++feature_importances[models_[iter]->split_feature_real(split_idx)]; ++feature_importances[models_[iter]->split_feature_real(split_idx)];
} }
...@@ -392,38 +427,47 @@ void GBDT::FeatureImportance(const int last_iter) { ...@@ -392,38 +427,47 @@ void GBDT::FeatureImportance(const int last_iter) {
const std::pair<size_t, std::string>& rhs) { const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first; return lhs.first > rhs.first;
}); });
std::stringstream str_buf;
// write to model file // write to model file
fprintf(output_model_file, "\nfeature importances:\n"); str_buf << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) { for (size_t i = 0; i < pairs.size(); ++i) {
fprintf(output_model_file, "%s=%s\n", pairs[i].second.c_str(), str_buf << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
std::to_string(pairs[i].first).c_str());
} }
fflush(output_model_file); return str_buf.str();
} }
double GBDT::PredictRaw(const double* value) const { float GBDT::PredictRaw(const float* value, int num_used_model) const {
double ret = 0.0; if (num_used_model < 0) {
for (size_t i = 0; i < models_.size(); ++i) { num_used_model = static_cast<int>(models_.size());
}
float ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
return ret; return ret;
} }
double GBDT::Predict(const double* value) const { float GBDT::Predict(const float* value, int num_used_model) const {
double ret = 0.0; if (num_used_model < 0) {
for (size_t i = 0; i < models_.size(); ++i) { num_used_model = static_cast<int>(models_.size());
}
float ret = 0.0f;
for (int i = 0; i < num_used_model; ++i) {
ret += models_[i]->Predict(value); ret += models_[i]->Predict(value);
} }
// if need sigmoid transform // if need sigmoid transform
if (sigmoid_ > 0) { if (sigmoid_ > 0) {
ret = 1.0 / (1.0 + std::exp(- 2.0f * sigmoid_ * ret)); ret = 1.0f / (1.0f + std::exp(- 2.0f * sigmoid_ * ret));
} }
return ret; return ret;
} }
std::vector<int> GBDT::PredictLeafIndex(const double* value) const { std::vector<int> GBDT::PredictLeafIndex(const float* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
}
std::vector<int> ret; std::vector<int> ret;
for (size_t i = 0; i < models_.size(); ++i) { for (int i = 0; i < num_used_model; ++i) {
ret.push_back(models_[i]->PredictLeafIndex(value)); ret.push_back(models_[i]->PredictLeafIndex(value));
} }
return ret; return ret;
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
#include <string> #include <string>
#include <fstream>
namespace LightGBM { namespace LightGBM {
/*! /*!
...@@ -16,9 +17,8 @@ class GBDT: public Boosting { ...@@ -16,9 +17,8 @@ class GBDT: public Boosting {
public: public:
/*! /*!
* \brief Constructor * \brief Constructor
* \param config Config of GBDT
*/ */
explicit GBDT(const BoostingConfig* config); GBDT();
/*! /*!
* \brief Destructor * \brief Destructor
*/ */
...@@ -31,9 +31,8 @@ public: ...@@ -31,9 +31,8 @@ public:
* \param training_metrics Training metrics * \param training_metrics Training metrics
* \param output_model_filename Filename of output model * \param output_model_filename Filename of output model
*/ */
void Init(const Dataset* train_data, const ObjectiveFunction* object_function, void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics, const std::vector<const Metric*>& training_metrics)
const char* output_model_filename)
override; override;
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
...@@ -45,38 +44,47 @@ public: ...@@ -45,38 +44,47 @@ public:
/*! /*!
* \brief one training iteration * \brief one training iteration
*/ */
void Train() override; bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
/*! \brief Get eval result */
std::vector<std::string> EvalCurrent(bool is_eval_train) const override;
/*! \brief Get prediction result */
const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const override;
/*! /*!
* \brief Prediction for one record without sigmoid transformation * \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double PredictRaw(const double * feature_values) const override; float PredictRaw(const float* feature_values, int num_used_model) const override;
/*! /*!
* \brief Prediction for one record with sigmoid transformation if enabled * \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Prediction result for this record * \return Prediction result for this record
*/ */
double Predict(const double * feature_values) const override; float Predict(const float* feature_values, int num_used_model) const override;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
std::vector<int> PredictLeafIndex(const double* value) const override; std::vector<int> PredictLeafIndex(const float* value, int num_used_model) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
* \return String output of trained model * \return String output of trained model
*/ */
std::string ModelsToString() const override; void SaveModelToFile(bool is_finish, const char* filename) override;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model
*/ */
void ModelsFromString(const std::string& model_str, int num_used_model) override; void ModelsFromString(const std::string& model_str) override;
/*! /*!
* \brief Get max feature index of this model * \brief Get max feature index of this model
* \return Max feature index of this model * \return Max feature index of this model
...@@ -95,6 +103,11 @@ public: ...@@ -95,6 +103,11 @@ public:
*/ */
inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); } inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }
/*!
* \brief Get Type name of this boosting object
*/
const char* Name() const override { return "gbdt"; }
private: private:
/*! /*!
* \brief Implement bagging logic * \brief Implement bagging logic
...@@ -112,11 +125,6 @@ private: ...@@ -112,11 +125,6 @@ private:
*/ */
void Boosting(); void Boosting();
/*! /*!
* \brief training one tree
* \return Trained tree of this iteration
*/
Tree* TrainOneTree();
/*!
* \brief updating score after tree was trained * \brief updating score after tree was trained
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
*/ */
...@@ -130,8 +138,9 @@ private: ...@@ -130,8 +138,9 @@ private:
* \brief Calculate feature importances * \brief Calculate feature importances
* \param last_iter Last tree use to calculate * \param last_iter Last tree use to calculate
*/ */
void FeatureImportance(const int last_iter); std::string FeatureImportance() const;
/*! \brief current iteration */
int iter_;
/*! \brief Pointer to training data */ /*! \brief Pointer to training data */
const Dataset* train_data_; const Dataset* train_data_;
/*! \brief Config of gbdt */ /*! \brief Config of gbdt */
...@@ -173,16 +182,17 @@ private: ...@@ -173,16 +182,17 @@ private:
data_size_t num_data_; data_size_t num_data_;
/*! \brief Random generator, used for bagging */ /*! \brief Random generator, used for bagging */
Random random_; Random random_;
/*! \brief The filename that the models will save to */
FILE * output_model_file;
/*! /*!
* \brief Sigmoid parameter, used for prediction. * \brief Sigmoid parameter, used for prediction.
* if > 0 means the output score will be transformed by the sigmoid function * if > 0 means the output score will be transformed by the sigmoid function
*/ */
double sigmoid_; float sigmoid_;
/*! \brief Index of label column */ /*! \brief Index of label column */
data_size_t label_idx_; data_size_t label_idx_;
/*! \brief Saved number of models */
int saved_model_size_ = -1;
/*! \brief File to write models */
std::ofstream model_output_file_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -23,7 +23,7 @@ BinMapper::BinMapper(const BinMapper& other) ...@@ -23,7 +23,7 @@ BinMapper::BinMapper(const BinMapper& other)
num_bin_ = other.num_bin_; num_bin_ = other.num_bin_;
is_trival_ = other.is_trival_; is_trival_ = other.is_trival_;
sparse_rate_ = other.sparse_rate_; sparse_rate_ = other.sparse_rate_;
bin_upper_bound_ = new double[num_bin_]; bin_upper_bound_ = new float[num_bin_];
for (int i = 0; i < num_bin_; ++i) { for (int i = 0; i < num_bin_; ++i) {
bin_upper_bound_[i] = other.bin_upper_bound_[i]; bin_upper_bound_[i] = other.bin_upper_bound_[i];
} }
...@@ -38,10 +38,10 @@ BinMapper::~BinMapper() { ...@@ -38,10 +38,10 @@ BinMapper::~BinMapper() {
delete[] bin_upper_bound_; delete[] bin_upper_bound_;
} }
void BinMapper::FindBin(std::vector<double>* values, int max_bin) { void BinMapper::FindBin(std::vector<float>* values, int max_bin) {
size_t sample_size = values->size(); size_t sample_size = values->size();
// find distinct_values first // find distinct_values first
double* distinct_values = new double[sample_size]; float* distinct_values = new float[sample_size];
int *counts = new int[sample_size]; int *counts = new int[sample_size];
int num_values = 1; int num_values = 1;
std::sort(values->begin(), values->end()); std::sort(values->begin(), values->end());
...@@ -61,19 +61,19 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) { ...@@ -61,19 +61,19 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) {
if (num_values <= max_bin) { if (num_values <= max_bin) {
// use distinct value is enough // use distinct value is enough
num_bin_ = num_values; num_bin_ = num_values;
bin_upper_bound_ = new double[num_values]; bin_upper_bound_ = new float[num_values];
for (int i = 0; i < num_values - 1; ++i) { for (int i = 0; i < num_values - 1; ++i) {
bin_upper_bound_[i] = (distinct_values[i] + distinct_values[i + 1]) / 2; bin_upper_bound_[i] = (distinct_values[i] + distinct_values[i + 1]) / 2;
} }
cnt_in_bin0 = counts[0]; cnt_in_bin0 = counts[0];
bin_upper_bound_[num_values - 1] = std::numeric_limits<double>::infinity(); bin_upper_bound_[num_values - 1] = std::numeric_limits<float>::infinity();
} else { } else {
// need find bins // need find bins
num_bin_ = max_bin; num_bin_ = max_bin;
bin_upper_bound_ = new double[max_bin]; bin_upper_bound_ = new float[max_bin];
double * bin_lower_bound = new double[max_bin]; float * bin_lower_bound = new float[max_bin];
// mean size for one bin // mean size for one bin
double mean_bin_size = sample_size / static_cast<double>(max_bin); float mean_bin_size = sample_size / static_cast<float>(max_bin);
int rest_sample_cnt = static_cast<int>(sample_size); int rest_sample_cnt = static_cast<int>(sample_size);
int cur_cnt_inbin = 0; int cur_cnt_inbin = 0;
int bin_cnt = 0; int bin_cnt = 0;
...@@ -88,24 +88,24 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) { ...@@ -88,24 +88,24 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) {
++bin_cnt; ++bin_cnt;
bin_lower_bound[bin_cnt] = distinct_values[i + 1]; bin_lower_bound[bin_cnt] = distinct_values[i + 1];
cur_cnt_inbin = 0; cur_cnt_inbin = 0;
mean_bin_size = rest_sample_cnt / static_cast<double>(max_bin - bin_cnt); mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt);
} }
} }
cur_cnt_inbin += counts[num_values - 1]; cur_cnt_inbin += counts[num_values - 1];
// update bin upper bound // update bin upper bound
for (int i = 0; i < bin_cnt; ++i) { for (int i = 0; i < bin_cnt; ++i) {
bin_upper_bound_[i] = (bin_upper_bound_[i] + bin_lower_bound[i + 1]) / 2.0; bin_upper_bound_[i] = (bin_upper_bound_[i] + bin_lower_bound[i + 1]) / 2.0f;
} }
// last bin upper bound // last bin upper bound
bin_upper_bound_[bin_cnt] = std::numeric_limits<double>::infinity(); bin_upper_bound_[bin_cnt] = std::numeric_limits<float>::infinity();
++bin_cnt; ++bin_cnt;
delete[] bin_lower_bound; delete[] bin_lower_bound;
// if no so much bin // if no so much bin
if (bin_cnt < max_bin) { if (bin_cnt < max_bin) {
// old bin data // old bin data
double * tmp_bin_upper_bound = bin_upper_bound_; float* tmp_bin_upper_bound = bin_upper_bound_;
num_bin_ = bin_cnt; num_bin_ = bin_cnt;
bin_upper_bound_ = new double[num_bin_]; bin_upper_bound_ = new float[num_bin_];
// copy back // copy back
for (int i = 0; i < num_bin_; ++i) { for (int i = 0; i < num_bin_; ++i) {
bin_upper_bound_[i] = tmp_bin_upper_bound[i]; bin_upper_bound_[i] = tmp_bin_upper_bound[i];
...@@ -123,7 +123,7 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) { ...@@ -123,7 +123,7 @@ void BinMapper::FindBin(std::vector<double>* values, int max_bin) {
is_trival_ = false; is_trival_ = false;
} }
// calculate sparse rate // calculate sparse rate
sparse_rate_ = static_cast<double>(cnt_in_bin0) / static_cast<double>(sample_size); sparse_rate_ = static_cast<float>(cnt_in_bin0) / static_cast<float>(sample_size);
} }
...@@ -131,8 +131,8 @@ int BinMapper::SizeForSpecificBin(int bin) { ...@@ -131,8 +131,8 @@ int BinMapper::SizeForSpecificBin(int bin) {
int size = 0; int size = 0;
size += sizeof(int); size += sizeof(int);
size += sizeof(bool); size += sizeof(bool);
size += sizeof(double); size += sizeof(float);
size += bin * sizeof(double); size += bin * sizeof(float);
return size; return size;
} }
...@@ -143,7 +143,7 @@ void BinMapper::CopyTo(char * buffer) { ...@@ -143,7 +143,7 @@ void BinMapper::CopyTo(char * buffer) {
buffer += sizeof(is_trival_); buffer += sizeof(is_trival_);
std::memcpy(buffer, &sparse_rate_, sizeof(sparse_rate_)); std::memcpy(buffer, &sparse_rate_, sizeof(sparse_rate_));
buffer += sizeof(sparse_rate_); buffer += sizeof(sparse_rate_);
std::memcpy(buffer, bin_upper_bound_, num_bin_ * sizeof(double)); std::memcpy(buffer, bin_upper_bound_, num_bin_ * sizeof(float));
} }
void BinMapper::CopyFrom(const char * buffer) { void BinMapper::CopyFrom(const char * buffer) {
...@@ -154,19 +154,19 @@ void BinMapper::CopyFrom(const char * buffer) { ...@@ -154,19 +154,19 @@ void BinMapper::CopyFrom(const char * buffer) {
std::memcpy(&sparse_rate_, buffer, sizeof(sparse_rate_)); std::memcpy(&sparse_rate_, buffer, sizeof(sparse_rate_));
buffer += sizeof(sparse_rate_); buffer += sizeof(sparse_rate_);
if (bin_upper_bound_ != nullptr) { delete[] bin_upper_bound_; } if (bin_upper_bound_ != nullptr) { delete[] bin_upper_bound_; }
bin_upper_bound_ = new double[num_bin_]; bin_upper_bound_ = new float[num_bin_];
std::memcpy(bin_upper_bound_, buffer, num_bin_ * sizeof(double)); std::memcpy(bin_upper_bound_, buffer, num_bin_ * sizeof(float));
} }
void BinMapper::SaveBinaryToFile(FILE* file) const { void BinMapper::SaveBinaryToFile(FILE* file) const {
fwrite(&num_bin_, sizeof(num_bin_), 1, file); fwrite(&num_bin_, sizeof(num_bin_), 1, file);
fwrite(&is_trival_, sizeof(is_trival_), 1, file); fwrite(&is_trival_, sizeof(is_trival_), 1, file);
fwrite(&sparse_rate_, sizeof(sparse_rate_), 1, file); fwrite(&sparse_rate_, sizeof(sparse_rate_), 1, file);
fwrite(bin_upper_bound_, sizeof(double), num_bin_, file); fwrite(bin_upper_bound_, sizeof(float), num_bin_, file);
} }
size_t BinMapper::SizesInByte() const { size_t BinMapper::SizesInByte() const {
return sizeof(num_bin_) + sizeof(is_trival_) + sizeof(sparse_rate_) + sizeof(double) * num_bin_; return sizeof(num_bin_) + sizeof(is_trival_) + sizeof(sparse_rate_) + sizeof(float) * num_bin_;
} }
template class DenseBin<uint8_t>; template class DenseBin<uint8_t>;
...@@ -182,9 +182,9 @@ template class OrderedSparseBin<uint16_t>; ...@@ -182,9 +182,9 @@ template class OrderedSparseBin<uint16_t>;
template class OrderedSparseBin<uint32_t>; template class OrderedSparseBin<uint32_t>;
Bin* Bin::CreateBin(data_size_t num_data, int num_bin, double sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin) { Bin* Bin::CreateBin(data_size_t num_data, int num_bin, float sparse_rate, bool is_enable_sparse, bool* is_sparse, int default_bin) {
// sparse threshold // sparse threshold
const double kSparseThreshold = 0.8; const float kSparseThreshold = 0.8f;
if (sparse_rate >= kSparseThreshold && is_enable_sparse) { if (sparse_rate >= kSparseThreshold && is_enable_sparse) {
*is_sparse = true; *is_sparse = true;
return CreateSparseBin(num_data, num_bin, default_bin); return CreateSparseBin(num_data, num_bin, default_bin);
......
...@@ -10,6 +10,26 @@ ...@@ -10,6 +10,26 @@
namespace LightGBM { namespace LightGBM {
/*!
* \brief Parse "key=value" tokens from a string and apply them via Set
*
* The string is split on space, tab, newline and carriage return. A token that
* splits on '=' into exactly two parts becomes a parameter after trimming and
* quotation-symbol removal; tokens with an empty key are ignored and any other
* token is reported through Log::Error. Keys are passed through
* ParameterAlias::KeyAliasTransform before Set is called.
*
* \param str Parameter string
*/
void OverallConfig::LoadFromString(const char* str) {
  std::unordered_map<std::string, std::string> params;
  const auto args = Common::Split(str, " \t\n\r");
  // bind by reference: the tokens are only read, so copying each one is wasted work
  for (const auto& arg : args) {
    std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
    if (tmp_strs.size() != 2) {
      Log::Error("Unknown parameter %s", arg.c_str());
      continue;
    }
    std::string key = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[0]));
    std::string value = Common::RemoveQuotationSymbol(Common::Trim(tmp_strs[1]));
    // a token with nothing before '=' carries no parameter name
    if (key.empty()) {
      continue;
    }
    params[key] = value;
  }
  ParameterAlias::KeyAliasTransform(&params);
  Set(params);
}
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types // load main config types
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
...@@ -173,38 +193,34 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -173,38 +193,34 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) { void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "is_unbalance", &is_unbalance); GetBool(params, "is_unbalance", &is_unbalance);
GetDouble(params, "sigmoid", &sigmoid); GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "max_position", &max_position); GetInt(params, "max_position", &max_position);
CHECK(max_position > 0); CHECK(max_position > 0);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToDoubleArray(tmp_str, ','); label_gain = Common::StringToFloatArray(tmp_str, ',');
} else { } else {
// label_gain = 2^i - 1, may overflow, so we use 31 here // label_gain = 2^i - 1, may overflow, so we use 31 here
const int max_label = 31; const int max_label = 31;
label_gain.push_back(0.0); label_gain.push_back(0.0f);
for (int i = 1; i < max_label; ++i) { for (int i = 1; i < max_label; ++i) {
label_gain.push_back((1 << i) - 1); label_gain.push_back(static_cast<float>((1 << i) - 1));
} }
} }
} }
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "early_stopping_round", &early_stopping_round); GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0);
GetDouble(params, "sigmoid", &sigmoid);
GetBool(params, "is_training_metric", &is_provide_training_metric);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToDoubleArray(tmp_str, ','); label_gain = Common::StringToFloatArray(tmp_str, ',');
} else { } else {
// label_gain = 2^i - 1, may overflow, so we use 31 here // label_gain = 2^i - 1, may overflow, so we use 31 here
const int max_label = 31; const int max_label = 31;
label_gain.push_back(0.0); label_gain.push_back(0.0f);
for (int i = 1; i < max_label; ++i) { for (int i = 1; i < max_label; ++i) {
label_gain.push_back((1 << i) - 1); label_gain.push_back(static_cast<float>((1 << i) - 1));
} }
} }
if (GetString(params, "ndcg_eval_at", &tmp_str)) { if (GetString(params, "ndcg_eval_at", &tmp_str)) {
...@@ -224,14 +240,14 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param ...@@ -224,14 +240,14 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) { void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf); GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); GetFloat(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0); CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
GetInt(params, "num_leaves", &num_leaves); GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves > 1); CHECK(num_leaves > 1);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
GetDouble(params, "feature_fraction", &feature_fraction); GetFloat(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0); CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
GetDouble(params, "histogram_pool_size", &histogram_pool_size); GetFloat(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth); GetInt(params, "max_depth", &max_depth);
CHECK(max_depth > 1 || max_depth < 0); CHECK(max_depth > 1 || max_depth < 0);
} }
...@@ -243,12 +259,15 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -243,12 +259,15 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetInt(params, "bagging_seed", &bagging_seed); GetInt(params, "bagging_seed", &bagging_seed);
GetInt(params, "bagging_freq", &bagging_freq); GetInt(params, "bagging_freq", &bagging_freq);
CHECK(bagging_freq >= 0); CHECK(bagging_freq >= 0);
GetDouble(params, "bagging_fraction", &bagging_fraction); GetFloat(params, "bagging_fraction", &bagging_fraction);
CHECK(bagging_fraction > 0.0 && bagging_fraction <= 1.0); CHECK(bagging_fraction > 0.0f && bagging_fraction <= 1.0f);
GetDouble(params, "learning_rate", &learning_rate); GetFloat(params, "learning_rate", &learning_rate);
CHECK(learning_rate > 0.0); CHECK(learning_rate > 0.0f);
GetInt(params, "early_stopping_round", &early_stopping_round); GetInt(params, "early_stopping_round", &early_stopping_round);
CHECK(early_stopping_round >= 0); CHECK(early_stopping_round >= 0);
GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric);
} }
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
......
...@@ -274,10 +274,10 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti ...@@ -274,10 +274,10 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti
void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<std::string>& sample_data) { void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<std::string>& sample_data) {
// sample_values[i][j], means the value of j-th sample on i-th feature // sample_values[i][j], means the value of j-th sample on i-th feature
std::vector<std::vector<double>> sample_values; std::vector<std::vector<float>> sample_values;
// temp buffer for one line features and label // temp buffer for one line features and label
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
double label; float label;
for (size_t i = 0; i < sample_data.size(); ++i) { for (size_t i = 0; i < sample_data.size(); ++i) {
oneline_features.clear(); oneline_features.clear();
// parse features // parse features
...@@ -286,13 +286,13 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -286,13 +286,13 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
for (auto& feature_values : sample_values) { for (auto& feature_values : sample_values) {
feature_values.push_back(0.0); feature_values.push_back(0.0);
} }
for (std::pair<int, double>& inner_data : oneline_features) { for (std::pair<int, float>& inner_data : oneline_features) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) { if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
// if need expand feature set // if need expand feature set
size_t need_size = inner_data.first - sample_values.size() + 1; size_t need_size = inner_data.first - sample_values.size() + 1;
for (size_t j = 0; j < need_size; ++j) { for (size_t j = 0; j < need_size; ++j) {
// push i+1 0 // push i+1 0
sample_values.emplace_back(i + 1, 0.0); sample_values.emplace_back(i + 1, 0.0f);
} }
} }
// edit the feature value // edit the feature value
...@@ -507,8 +507,8 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -507,8 +507,8 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
} }
void Dataset::ExtractFeaturesFromMemory() { void Dataset::ExtractFeaturesFromMemory() {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
double tmp_label = 0.0; float tmp_label = 0.0f;
if (predict_fun_ == nullptr) { if (predict_fun_ == nullptr) {
// if doesn't need to prediction with initial model // if doesn't need to prediction with initial model
#pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label)
...@@ -577,7 +577,7 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -577,7 +577,7 @@ void Dataset::ExtractFeaturesFromMemory() {
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; i++) { for (int i = 0; i < num_features_; ++i) {
features_[i]->FinishLoad(); features_[i]->FinishLoad();
} }
// text data can be free after loaded feature values // text data can be free after loaded feature values
...@@ -593,10 +593,10 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -593,10 +593,10 @@ void Dataset::ExtractFeaturesFromFile() {
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &init_score] [this, &init_score]
(data_size_t start_idx, const std::vector<std::string>& lines) { (data_size_t start_idx, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, float>> oneline_features;
double tmp_label = 0.0; float tmp_label = 0.0f;
#pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); i++) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
oneline_features.clear(); oneline_features.clear();
// parser // parser
...@@ -639,7 +639,7 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -639,7 +639,7 @@ void Dataset::ExtractFeaturesFromFile() {
} }
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; i++) { for (int i = 0; i < num_features_; ++i) {
features_[i]->FinishLoad(); features_[i]->FinishLoad();
} }
} }
...@@ -805,7 +805,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -805,7 +805,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
const data_size_t* query_boundaries = metadata_.query_boundaries(); const data_size_t* query_boundaries = metadata_.query_boundaries();
if (query_boundaries == nullptr) { if (query_boundaries == nullptr) {
// if not contain query file, minimal sample unit is one record // if not contain query file, minimal sample unit is one record
for (data_size_t i = 0; i < num_data_; i++) { for (data_size_t i = 0; i < num_data_; ++i) {
if (random_.NextInt(0, num_machines) == rank) { if (random_.NextInt(0, num_machines) == rank) {
used_data_indices_.push_back(i); used_data_indices_.push_back(i);
} }
...@@ -815,7 +815,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit ...@@ -815,7 +815,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
data_size_t num_queries = metadata_.num_queries(); data_size_t num_queries = metadata_.num_queries();
data_size_t qid = -1; data_size_t qid = -1;
bool is_query_used = false; bool is_query_used = false;
for (data_size_t i = 0; i < num_data_; i++) { for (data_size_t i = 0; i < num_data_; ++i) {
if (qid >= num_queries) { if (qid >= num_queries) {
Log::Fatal("current query is exceed the range of query file, please ensure your query file is correct"); Log::Fatal("current query is exceed the range of query file, please ensure your query file is correct");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment