Unverified commit dc699574 authored by Guolin Ke, committed by GitHub

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generate

* auto generate config codes

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* tmp for failed test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some miss alias

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
...@@ -32,7 +32,7 @@ Core Parameters

 - **Note**: Only can be used in CLI version

-- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``
+- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``, alias=\ ``task_type``

 - ``train``, alias=\ ``training``, for training

...@@ -47,7 +47,7 @@ Core Parameters

 - ``application``, default=\ ``regression``, type=enum,
   options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gammma``, ``tweedie``,
   ``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``,
-  alias=\ ``objective``, ``app``
+  alias=\ ``app``, ``objective``, ``objective_type``

 - regression application

...@@ -107,11 +107,11 @@ Core Parameters

 - ``goss``, Gradient-based One-Side Sampling

-- ``data``, default=\ ``""``, type=string, alias=\ ``train``, ``train_data``
+- ``data``, default=\ ``""``, type=string, alias=\ ``train``, ``train_data``, ``data_filename``

 - training data, LightGBM will train from this data

-- ``valid``, default=\ ``""``, type=multi-string, alias=\ ``test``, ``valid_data``, ``test_data``
+- ``valid``, default=\ ``""``, type=multi-string, alias=\ ``test``, ``valid_data``, ``test_data``, ``valid_filenames``

 - validation/test data, LightGBM will output metrics for these data

...@@ -137,7 +137,7 @@ Core Parameters

 - number of leaves in one tree

-- ``tree_learner``, default=\ ``serial``, type=enum, options=\ ``serial``, ``feature``, ``data``, ``voting``, alias=\ ``tree``
+- ``tree_learner``, default=\ ``serial``, type=enum, options=\ ``serial``, ``feature``, ``data``, ``voting``, alias=\ ``tree``, ``tree_learner_type``

 - ``serial``, single machine tree learner

...@@ -149,7 +149,7 @@ Core Parameters

 - refer to `Parallel Learning Guide <./Parallel-Learning-Guide.rst>`__ to get more details

-- ``num_threads``, default=\ ``OpenMP_default``, type=int, alias=\ ``num_thread``, ``nthread``
+- ``num_threads``, default=\ ``OpenMP_default``, type=int, alias=\ ``num_thread``, ``nthread``, ``nthreads``

 - number of threads for LightGBM

...@@ -204,7 +204,7 @@ Learning Control Parameters

 - random seed for ``feature_fraction``

-- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``
+- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``, ``bagging``

 - like ``feature_fraction``, but this will randomly select part of data without resampling

...@@ -312,7 +312,7 @@ Learning Control Parameters

 - set this to larger value for more accurate result, but it will slow down the training speed

-- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``
+- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``, ``monotone_constraints``

 - used for constraints of monotonic features

...@@ -443,7 +443,7 @@ IO Parameters

 - **Note**: the negative values will be treated as **missing values**

-- ``predict_raw_score``, default=\ ``false``, type=bool, alias=\ ``raw_score``, ``is_predict_raw_score``
+- ``predict_raw_score``, default=\ ``false``, type=bool, alias=\ ``raw_score``, ``is_predict_raw_score``, ``predict_rawscore``

 - only used in ``prediction`` task

...@@ -501,17 +501,17 @@ IO Parameters

 - set to ``false`` to use ``na`` to represent missing values

-- ``init_score_file``, default=\ ``""``, type=string
+- ``init_score_file``, default=\ ``""``, type=string, alias=\ ``init_score_filename``, ``initscore_filename``, ``init_score``

 - path to training initial score file, ``""`` will use ``train_data_file`` + ``.init`` (if exists)

-- ``valid_init_score_file``, default=\ ``""``, type=multi-string
+- ``valid_init_score_file``, default=\ ``""``, type=multi-string, alias=\ ``valid_data_initscores``, ``valid_data_init_scores``, ``valid_init_score``

 - path to validation initial score file, ``""`` will use ``valid_data_file`` + ``.init`` (if exists)

 - separate by ``,`` for multi-validation data

-- ``forced_splits``, default=\ ``""``, type=string
+- ``forced_splits``, default=\ ``""``, type=string, alias=\ ``forced_splits_file``, ``forcedsplits_filename``, ``forced_splits_filename``

 - path to a ``.json`` file that specifies splits to force at the top of every decision tree before best-first learning commences

...@@ -593,7 +593,7 @@ Objective Parameters

 Metric Parameters
 -----------------

-- ``metric``, default=\ ``''``, type=multi-enum
+- ``metric``, default=\ ``''``, type=multi-enum, alias=\ ``metric_types``

 - metric to be evaluated on the evaluation sets **in addition** to what is provided in the training arguments

...@@ -650,7 +650,7 @@ Metric Parameters

 - frequency for metric output

-- ``train_metric``, default=\ ``false``, type=bool, alias=\ ``training_metric``, ``is_training_metric``
+- ``train_metric``, default=\ ``false``, type=bool, alias=\ ``training_metric``, ``is_training_metric``, ``is_provide_training_metric``

 - set this to ``true`` if you need to output metric result of training

...
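The alias additions in the documentation diff above can be illustrated with a short sketch (illustrative Python, not LightGBM code): a few alias → canonical-name pairs taken from the diff, plus a simplified resolver standing in for what LightGBM's alias handling does. The conflict rule here (a directly supplied canonical key wins over an alias) is an assumption of this sketch, not a statement about LightGBM's behavior.

```python
# Alias -> canonical-name pairs, taken from the documentation diff above.
# resolve_aliases is a simplified, hypothetical stand-in for LightGBM's
# alias resolution (Config::alias_table / ParameterAlias::KeyAliasTransform).
ALIAS_TABLE = {
    "task_type": "task",
    "objective_type": "application",
    "data_filename": "data",
    "valid_filenames": "valid",
    "tree_learner_type": "tree_learner",
    "nthreads": "num_threads",
    "bagging": "bagging_fraction",
    "monotone_constraints": "monotone_constraint",
    "metric_types": "metric",
}


def resolve_aliases(params):
    """Rewrite aliased keys to canonical names; a canonical key wins on conflict."""
    resolved = {}
    for key, value in params.items():
        canonical = ALIAS_TABLE.get(key, key)
        if canonical not in resolved or canonical == key:
            resolved[canonical] = value
    return resolved


print(resolve_aliases({"nthreads": 4, "metric_types": "auc"}))
# -> {'num_threads': 4, 'metric': 'auc'}
```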
import os


def GetParameterInfos(config_hpp):
    is_inparameter = False
    parameter_group = None
    cur_key = None
    cur_info = {}
    keys = []
    member_infos = []
    with open(config_hpp) as config_hpp_file:
        for line in config_hpp_file:
            if "#pragma region Parameters" in line:
                is_inparameter = True
            elif "#pragma region" in line and "Parameters" in line:
                cur_key = line.split("region")[1].strip()
                keys.append(cur_key)
                member_infos.append([])
            elif '#pragma endregion' in line:
                if cur_key is not None:
                    cur_key = None
                elif is_inparameter:
                    is_inparameter = False
            elif cur_key is not None:
                line = line.strip()
                if line.startswith("//"):
                    tokens = line.split("//")[1].split("=")
                    key = tokens[0].strip()
                    val = '='.join(tokens[1:]).strip()
                    if key not in cur_info:
                        if key == "descl2":
                            cur_info["desc"] = []
                        else:
                            cur_info[key] = []
                    if key == "desc":
                        cur_info["desc"].append(["l1", val])
                    elif key == "descl2":
                        cur_info["desc"].append(["l2", val])
                    else:
                        cur_info[key].append(val)
                elif line:
                    has_eqsgn = False
                    tokens = line.split("=")
                    if len(tokens) == 2:
                        if "default" not in cur_info:
                            cur_info["default"] = [tokens[1][:-1].strip()]
                        has_eqsgn = True
                    tokens = line.split()
                    cur_info["inner_type"] = [tokens[0].strip()]
                    if "name" not in cur_info:
                        if has_eqsgn:
                            cur_info["name"] = [tokens[1].strip()]
                        else:
                            cur_info["name"] = [tokens[1][:-1].strip()]
                    member_infos[-1].append(cur_info)
                    cur_info = {}
    return (keys, member_infos)


def GetNames(infos):
    names = []
    for x in infos:
        for y in x:
            names.append(y["name"][0])
    return names


def GetAlias(infos):
    pairs = []
    for x in infos:
        for y in x:
            if "alias" in y:
                name = y["name"][0]
                alias = y["alias"][0].split(',')
                for name2 in alias:
                    pairs.append([name2.strip(), name])
    return pairs


def SetOneVarFromString(name, type, checks):
    ret = ""
    univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"}
    if "vector" not in type:
        ret += "  %s(params, \"%s\", &%s);\n" % (univar_mapper[type], name, name)
        if len(checks) > 0:
            for check in checks:
                ret += "  CHECK(%s %s);\n" % (name, check)
        ret += "\n"
    else:
        ret += "  if (GetString(params, \"%s\", &tmp_str)) {\n" % (name)
        type2 = type.split("<")[1][:-1]
        if type2 == "std::string":
            ret += "    %s = Common::Split(tmp_str.c_str(), ',');\n" % (name)
        else:
            ret += "    %s = Common::StringToArray<%s>(tmp_str, ',');\n" % (name, type2)
        ret += "  }\n\n"
    return ret


def GenParameterCode(config_hpp, config_out_cpp):
    keys, infos = GetParameterInfos(config_hpp)
    names = GetNames(infos)
    alias = GetAlias(infos)
    str_to_write = "/// This file is auto generated by LightGBM\\helper\\parameter_generator.py\n"
    str_to_write += "#include<LightGBM/config.h>\nnamespace LightGBM {\n"
    # alias table
    str_to_write += "std::unordered_map<std::string, std::string> Config::alias_table({\n"
    for pair in alias:
        str_to_write += "  {\"%s\", \"%s\"}, \n" % (pair[0], pair[1])
    str_to_write += "});\n\n"
    # names
    str_to_write += "std::unordered_set<std::string> Config::parameter_set({\n"
    for name in names:
        str_to_write += "  \"%s\", \n" % (name)
    str_to_write += "});\n\n"
    # from strings
    str_to_write += "void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {\n"
    str_to_write += "  std::string tmp_str = \"\";\n"
    for x in infos:
        for y in x:
            if "[doc-only]" in y:
                continue
            type = y["inner_type"][0]
            name = y["name"][0]
            checks = []
            if "check" in y:
                checks = y["check"]
            tmp = SetOneVarFromString(name, type, checks)
            str_to_write += tmp
    # tails
    str_to_write += "}\n\n"
    str_to_write += "std::string Config::SaveMembersToString() const {\n"
    str_to_write += "  std::stringstream str_buf;\n"
    for x in infos:
        for y in x:
            if "[doc-only]" in y:
                continue
            type = y["inner_type"][0]
            name = y["name"][0]
            if "vector" in type:
                if "int8" in type:
                    str_to_write += "  str_buf << \"[%s: \" << Common::Join(Common::ArrayCast<int8_t, int>(%s),\",\") << \"]\\n\";\n" % (name, name)
                else:
                    str_to_write += "  str_buf << \"[%s: \" << Common::Join(%s,\",\") << \"]\\n\";\n" % (name, name)
            else:
                str_to_write += "  str_buf << \"[%s: \" << %s << \"]\\n\";\n" % (name, name)
    # tails
    str_to_write += "  return str_buf.str();\n"
    str_to_write += "}\n\n"
    str_to_write += "}\n"
    with open(config_out_cpp, "w") as config_out_cpp_file:
        config_out_cpp_file.write(str_to_write)


if __name__ == "__main__":
    config_hpp = os.path.join(os.path.pardir, 'include', 'LightGBM', 'config.h')
    config_out_cpp = os.path.join(os.path.pardir, 'src', 'io', 'config_auto.cpp')
    GenParameterCode(config_hpp, config_out_cpp)
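The generator above keys off structured ``// key = value`` doc comments inside the parameter regions of config.h. As an illustration, the comment-splitting step from GetParameterInfos can be exercised in isolation on a hypothetical config.h fragment written in that format (the fragment and the helper below are examples, not code from the repository):

```python
# Minimal restatement of the doc-comment parsing step used by the generator:
# a "// key = value" comment contributes one (key, value) entry; values are
# re-joined on '=' so they may themselves contain '=' characters.
def parse_doc_comments(lines):
    info = {}
    for line in lines:
        line = line.strip()
        if not line.startswith("//"):
            continue
        tokens = line.split("//")[1].split("=")
        key = tokens[0].strip()
        val = "=".join(tokens[1:]).strip()
        info.setdefault(key, []).append(val)
    return info


# Hypothetical config.h fragment in the expected format.
snippet = [
    "// alias = num_thread, nthread, nthreads",
    "// desc = number of threads for LightGBM",
    "int num_threads = 0;",
]
print(parse_doc_comments(snippet))
# -> {'alias': ['num_thread, nthread, nthreads'], 'desc': ['number of threads for LightGBM']}
```

Per the ``__main__`` block of the script, running it from its own directory reads ``../include/LightGBM/config.h`` and regenerates ``../src/io/config_auto.cpp``.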
...@@ -56,7 +56,7 @@ private:

   void ConvertModel();

   /*! \brief All configs */
-  OverallConfig config_;
+  Config config_;
   /*! \brief Training data */
   std::unique_ptr<Dataset> train_data_;
   /*! \brief Validation data */

...@@ -73,10 +73,10 @@ private:

 inline void Application::Run() {
-  if (config_.task_type == TaskType::kPredict || config_.task_type == TaskType::KRefitTree) {
+  if (config_.task == TaskType::kPredict || config_.task == TaskType::KRefitTree) {
     InitPredict();
     Predict();
-  } else if (config_.task_type == TaskType::kConvertModel) {
+  } else if (config_.task == TaskType::kConvertModel) {
     ConvertModel();
   } else {
     InitTrain();
...
...@@ -32,7 +32,7 @@ public:

   * \param training_metrics Training metric
   */
   virtual void Init(
-    const BoostingConfig* config,
+    const Config* config,
     const Dataset* train_data,
     const ObjectiveFunction* objective_function,
     const std::vector<const Metric*>& training_metrics) = 0;

...@@ -47,7 +47,7 @@ public:

   virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                                  const std::vector<const Metric*>& training_metrics) = 0;

-  virtual void ResetConfig(const BoostingConfig* config) = 0;
+  virtual void ResetConfig(const Config* config) = 0;
...
This diff is collapsed.
...@@ -273,7 +273,7 @@ public:

   * \param label_idx index of label column
   * \return Object of parser
   */
-  static Parser* CreateParser(const char* filename, bool has_header, int num_features, int label_idx);
+  static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx);
 };

 /*! \brief The main class of data set,

...@@ -292,7 +292,7 @@ public:

   int** sample_non_zero_indices,
   const int* num_per_col,
   size_t total_sample_cnt,
-  const IOConfig& io_config);
+  const Config& io_config);
   /*! \brief Destructor */
   LIGHTGBM_EXPORT ~Dataset();
...
...@@ -8,7 +8,7 @@ namespace LightGBM {

 class DatasetLoader {
  public:
-  LIGHTGBM_EXPORT DatasetLoader(const IOConfig& io_config, const PredictFunction& predict_fun, int num_class, const char* filename);
+  LIGHTGBM_EXPORT DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename);
   LIGHTGBM_EXPORT ~DatasetLoader();

...@@ -54,7 +54,7 @@ private:

   /*! \brief Check can load from binary file */
   std::string CheckCanLoadFromBin(const char* filename);

-  const IOConfig& io_config_;
+  const Config& config_;
   /*! \brief Random generator*/
   Random random_;
   /*! \brief prediction function for initial model */
...
...@@ -47,7 +47,7 @@ public:

   * \param type Specific type of metric
   * \param config Config for metric
   */
-  LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type, const MetricConfig& config);
+  LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type, const Config& config);
 };

...@@ -56,11 +56,14 @@ public:

  */
 class DCGCalculator {
  public:
+  static void DefaultEvalAt(std::vector<int>* eval_at);
+  static void DefaultLabelGain(std::vector<double>* label_gain);
   /*!
   * \brief Initial logic
   * \param label_gain Gain for labels, default is 2^i - 1
   */
-  static void Init(std::vector<double> label_gain);
+  static void Init(const std::vector<double>& label_gain);
   /*!
   * \brief Calculate the DCG score at position k
...
...@@ -89,7 +89,7 @@ public:

   * \brief Initialize
   * \param config Config of network setting
   */
-  static void Init(NetworkConfig config);
+  static void Init(Config config);
   /*!
   * \brief Initialize
   */
...
...@@ -71,7 +71,7 @@ public:

   * \param config Config for objective function
   */
   LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& type,
-    const ObjectiveConfig& config);
+    const Config& config);
   /*!
   * \brief Load objective function from string object
...
...@@ -170,7 +170,7 @@ public:

   std::string ToJSON() const;

   /*! \brief Serialize this object to if-else statement*/
-  std::string ToIfElse(int index, bool is_predict_leaf_index) const;
+  std::string ToIfElse(int index, bool predict_leaf_index) const;

   inline static bool IsZero(double fval) {
     if (fval > -kZeroThreshold && fval <= kZeroThreshold) {

...@@ -307,9 +307,9 @@ private:

   std::string NodeToJSON(int index) const;

   /*! \brief Serialize one node to if-else statement*/
-  std::string NodeToIfElse(int index, bool is_predict_leaf_index) const;
-  std::string NodeToIfElseByMap(int index, bool is_predict_leaf_index) const;
+  std::string NodeToIfElse(int index, bool predict_leaf_index) const;
+  std::string NodeToIfElseByMap(int index, bool predict_leaf_index) const;

   double ExpectedValue() const;
...
...@@ -36,9 +36,9 @@ public:

   /*!
   * \brief Reset tree configs
-  * \param tree_config config of tree
+  * \param config config of tree
   */
-  virtual void ResetConfig(const TreeConfig* tree_config) = 0;
+  virtual void ResetConfig(const Config* config) = 0;

   /*!
   * \brief training tree model on dataset

...@@ -85,11 +85,11 @@ public:

   * \brief Create object of tree learner
   * \param learner_type Type of tree learner
   * \param device_type Type of tree learner
-  * \param tree_config config of tree
+  * \param config config of tree
   */
   static TreeLearner* CreateTreeLearner(const std::string& learner_type,
                                         const std::string& device_type,
-                                        const TreeConfig* tree_config);
+                                        const Config* config);
 };

 }  // namespace LightGBM
...
...@@ -33,7 +33,7 @@ Application::Application(int argc, char** argv) {

   if (config_.num_threads > 0) {
     omp_set_num_threads(config_.num_threads);
   }
-  if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
+  if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
     Log::Fatal("No training/prediction data, application quit");
   }
   omp_set_nested(0);

...@@ -48,13 +48,13 @@ Application::~Application() {

 void Application::LoadParameters(int argc, char** argv) {
   std::unordered_map<std::string, std::string> params;
   for (int i = 1; i < argc; ++i) {
-    ConfigBase::KV2Map(params, argv[i]);
+    Config::KV2Map(params, argv[i]);
   }
   // check for alias
   ParameterAlias::KeyAliasTransform(&params);
   // read parameters from config file
-  if (params.count("config_file") > 0) {
-    TextReader<size_t> config_reader(params["config_file"].c_str(), false);
+  if (params.count("config") > 0) {
+    TextReader<size_t> config_reader(params["config"].c_str(), false);
     config_reader.ReadAllLines();
     if (!config_reader.Lines().empty()) {
       for (auto& line : config_reader.Lines()) {

...@@ -66,11 +66,11 @@ void Application::LoadParameters(int argc, char** argv) {

         if (line.size() == 0) {
           continue;
         }
-        ConfigBase::KV2Map(params, line.c_str());
+        Config::KV2Map(params, line.c_str());
       }
     } else {
       Log::Warning("Config file %s doesn't exist, will ignore",
-                   params["config_file"].c_str());
+                   params["config"].c_str());
     }
   }
   // check for alias again
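LoadParameters above folds ``key=value`` pairs from both argv and the lines of a config file through KV2Map before alias resolution. A rough Python equivalent of that parsing loop looks like the sketch below; the duplicate-handling rule (first occurrence wins) is an assumption of this sketch, not a claim about KV2Map's actual behavior.

```python
# Sketch of key=value parameter parsing as performed on CLI tokens and
# config-file lines. Tokens without '=' are ignored; duplicates keep the
# first value seen (assumed behavior for this illustration).
def kv2map(params, token):
    if "=" not in token:
        return
    key, _, value = token.partition("=")
    key, value = key.strip(), value.strip()
    if key and key not in params:
        params[key] = value


params = {}
for arg in ["task=train", "config=train.conf"]:      # argv-style tokens
    kv2map(params, arg)
for line in ["data = train.txt", "num_threads = 4"]:  # config-file lines
    kv2map(params, line)
print(params)
# -> {'task': 'train', 'config': 'train.conf', 'data': 'train.txt', 'num_threads': '4'}
```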
...@@ -87,37 +87,37 @@ void Application::LoadData() {

   PredictFunction predict_fun = nullptr;
   PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
   // need to continue training
-  if (boosting_->NumberOfTotalModel() > 0 && config_.task_type != TaskType::KRefitTree) {
+  if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) {
     predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
     predict_fun = predictor->GetPredictFunction();
   }
   // sync up random seed for data partition
   if (config_.is_parallel_find_bin) {
-    config_.io_config.data_random_seed = Network::GlobalSyncUpByMin(config_.io_config.data_random_seed);
+    config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
   }
-  DatasetLoader dataset_loader(config_.io_config, predict_fun,
-                               config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
+  DatasetLoader dataset_loader(config_, predict_fun,
+                               config_.num_class, config_.data.c_str());
   // load Training data
   if (config_.is_parallel_find_bin) {
     // load data for parallel training
-    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
-                                                  config_.io_config.initscore_filename.c_str(),
+    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
+                                                  config_.initscore_filename.c_str(),
                                                   Network::rank(), Network::num_machines()));
   } else {
     // load data for single machine
-    train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
+    train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
                                                   0, 1));
   }
   // need save binary file
-  if (config_.io_config.is_save_binary_file) {
+  if (config_.save_binary) {
     train_data_->SaveBinaryFile(nullptr);
   }
   // create training metric
-  if (config_.boosting_config.is_provide_training_metric) {
-    for (auto metric_type : config_.metric_types) {
-      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
+  if (config_.is_provide_training_metric) {
+    for (auto metric_type : config_.metric) {
+      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
       if (metric == nullptr) { continue; }
       metric->Init(train_data_->metadata(), train_data_->num_data());
       train_metric_.push_back(std::move(metric));

...@@ -126,28 +126,28 @@ void Application::LoadData() {

   train_metric_.shrink_to_fit();

-  if (!config_.metric_types.empty()) {
+  if (!config_.metric.empty()) {
     // only when have metrics then need to construct validation data
     // Add validation data, if it exists
-    for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
+    for (size_t i = 0; i < config_.valid.size(); ++i) {
       // add
       auto new_dataset = std::unique_ptr<Dataset>(
         dataset_loader.LoadFromFileAlignWithOtherDataset(
-          config_.io_config.valid_data_filenames[i].c_str(),
-          config_.io_config.valid_data_initscores[i].c_str(),
+          config_.valid[i].c_str(),
+          config_.valid_data_initscores[i].c_str(),
           train_data_.get())
       );
       valid_datas_.push_back(std::move(new_dataset));
       // need save binary file
-      if (config_.io_config.is_save_binary_file) {
+      if (config_.save_binary) {
         valid_datas_.back()->SaveBinaryFile(nullptr);
       }
       // add metric for validation data
       valid_metrics_.emplace_back();
-      for (auto metric_type : config_.metric_types) {
-        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
+      for (auto metric_type : config_.metric) {
+        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
         if (metric == nullptr) { continue; }
         metric->Init(valid_datas_.back()->metadata(),
                      valid_datas_.back()->num_data());

...@@ -167,30 +167,30 @@ void Application::LoadData() {

 void Application::InitTrain() {
   if (config_.is_parallel) {
     // need init network
-    Network::Init(config_.network_config);
+    Network::Init(config_);
     Log::Info("Finished initializing network");
-    config_.boosting_config.tree_config.feature_fraction_seed =
-      Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction_seed);
-    config_.boosting_config.tree_config.feature_fraction =
-      Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction);
-    config_.boosting_config.drop_seed =
-      Network::GlobalSyncUpByMin(config_.boosting_config.drop_seed);
+    config_.feature_fraction_seed =
+      Network::GlobalSyncUpByMin(config_.feature_fraction_seed);
+    config_.feature_fraction =
+      Network::GlobalSyncUpByMin(config_.feature_fraction);
+    config_.drop_seed =
+      Network::GlobalSyncUpByMin(config_.drop_seed);
   }
   // create boosting
   boosting_.reset(
-    Boosting::CreateBoosting(config_.boosting_type,
-                             config_.io_config.input_model.c_str()));
+    Boosting::CreateBoosting(config_.boosting,
+                             config_.input_model.c_str()));
   // create objective function
   objective_fun_.reset(
-    ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
-                                               config_.objective_config));
+    ObjectiveFunction::CreateObjectiveFunction(config_.objective,
+                                               config_));
   // load training data
   LoadData();
   // initialize the objective function
   objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
   // initialize the boosting
-  boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
+  boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
                   Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
   // add validation data into boosting
   for (size_t i = 0; i < valid_datas_.size(); ++i) {

...@@ -202,22 +202,22 @@ void Application::InitTrain() {

 void Application::Train() {
   Log::Info("Started training...");
-  boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
-  boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
+  boosting_->Train(config_.snapshot_freq, config_.output_model);
+  boosting_->SaveModelToFile(-1, config_.output_model.c_str());
   // convert model to if-else statement code
   if (config_.convert_model_language == std::string("cpp")) {
-    boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
+    boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
   }
   Log::Info("Finished training");
 }

 void Application::Predict() {
-  if (config_.task_type == TaskType::KRefitTree) {
+  if (config_.task == TaskType::KRefitTree) {
     // create predictor
     Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
-    predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str(), config_.io_config.has_header);
-    TextReader<int> result_reader(config_.io_config.output_result.c_str(), false);
+    predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header);
+    TextReader<int> result_reader(config_.output_result.c_str(), false);
     result_reader.ReadAllLines();
     std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
 #pragma omp parallel for schedule(static)

...@@ -226,41 +226,41 @@ void Application::Predict() {

     // Free memory
     result_reader.Lines()[i].clear();
   }
-  DatasetLoader dataset_loader(config_.io_config, nullptr,
-                               config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
-  train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
+  DatasetLoader dataset_loader(config_, nullptr,
+                               config_.num_class, config_.data.c_str());
+  train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
                                                 0, 1));
   train_metric_.clear();
-  objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
-                                                                  config_.objective_config));
+  objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
+                                                                  config_));
   objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
-  boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
+  boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
                   Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
   boosting_->RefitTree(pred_leaf);
-  boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
+  boosting_->SaveModelToFile(-1, config_.output_model.c_str());
   Log::Info("Finished RefitTree");
 } else {
   // create predictor
-  Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
-                      config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
+  Predictor predictor(boosting_.get(), config_.num_iteration_predict, config_.predict_raw_score,
+                      config_.predict_leaf_index, config_.predict_contrib,
config_.io_config.pred_early_stop, config_.io_config.pred_early_stop_freq, config_.pred_early_stop, config_.pred_early_stop_freq,
config_.io_config.pred_early_stop_margin); config_.pred_early_stop_margin);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.data.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.output_result.c_str(), config_.header);
Log::Info("Finished prediction"); Log::Info("Finished prediction");
} }
} }
void Application::InitPredict() { void Application::InitPredict() {
boosting_.reset( boosting_.reset(
Boosting::CreateBoosting("gbdt", config_.io_config.input_model.c_str())); Boosting::CreateBoosting("gbdt", config_.input_model.c_str()));
Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration()); Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
} }
void Application::ConvertModel() { void Application::ConvertModel() {
boosting_.reset( boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type, config_.io_config.input_model.c_str())); Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str()); boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
} }
......
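The hunks above all make the same change: parameters that used to live on nested sub-configs (`config_.io_config`, `config_.boosting_config`) become flat members of a single `Config` object. A minimal standalone sketch of the before/after access pattern — the struct shapes and the `output_model` default here are illustrative, not the library's actual declarations:

```cpp
#include <cassert>
#include <string>

// Illustrative pre-refactor shape: parameters scattered across sub-configs.
struct OldIOConfig { std::string output_model = "LightGBM_model.txt"; };
struct OldConfig { OldIOConfig io_config; };

// Refactored shape: the same parameter as a flat member of one Config.
struct Config { std::string output_model = "LightGBM_model.txt"; };

// Call sites shrink from c.io_config.output_model to c.output_model.
std::string OldAccess(const OldConfig& c) { return c.io_config.output_model; }
std::string NewAccess(const Config& c) { return c.output_model; }
```

Beyond shorter call sites, a single flat object lets one pointer (`config_.get()`) be handed to the dataset loader, tree learner, and boosting code alike, which is what the later hunks rely on.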
@@ -29,11 +29,11 @@ public:
   * \param boosting Input boosting model
   * \param num_iteration Number of boosting round
   * \param is_raw_score True if need to predict result with raw score
-  * \param is_predict_leaf_index True to output leaf index instead of prediction score
-  * \param is_predict_contrib True to output feature contributions instead of prediction score
+  * \param predict_leaf_index True to output leaf index instead of prediction score
+  * \param predict_contrib True to output feature contributions instead of prediction score
   */
  Predictor(Boosting* boosting, int num_iteration,
-           bool is_raw_score, bool is_predict_leaf_index, bool is_predict_contrib,
+           bool is_raw_score, bool predict_leaf_index, bool predict_contrib,
            bool early_stop, int early_stop_freq, double early_stop_margin) {
    early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
@@ -55,14 +55,14 @@ public:
    {
      num_threads_ = omp_get_num_threads();
    }
-   boosting->InitPredict(num_iteration, is_predict_contrib);
+   boosting->InitPredict(num_iteration, predict_contrib);
    boosting_ = boosting;
-   num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
+   num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, predict_leaf_index, predict_contrib);
    num_feature_ = boosting_->MaxFeatureIdx() + 1;
    predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
    const int kFeatureThreshold = 100000;
    const size_t KSparseThreshold = static_cast<size_t>(0.01 * num_feature_);
-   if (is_predict_leaf_index) {
+   if (predict_leaf_index) {
      predict_fun_ = [this, kFeatureThreshold, KSparseThreshold](const std::vector<std::pair<int, double>>& features, double* output) {
        int tid = omp_get_thread_num();
        if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
@@ -75,7 +75,7 @@ public:
          ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
        }
      };
-   } else if (is_predict_contrib) {
+   } else if (predict_contrib) {
      predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
        int tid = omp_get_thread_num();
        CopyToPredictBuffer(predict_buf_[tid].data(), features);
@@ -127,27 +127,27 @@ public:
   * \param data_filename Filename of data
   * \param result_filename Filename of output result
   */
-  void Predict(const char* data_filename, const char* result_filename, bool has_header) {
+  void Predict(const char* data_filename, const char* result_filename, bool header) {
    auto writer = VirtualFileWriter::Make(result_filename);
    if (!writer->Init()) {
      Log::Fatal("Prediction results file %s cannot be found", result_filename);
    }
-   auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
+   auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize the data format of data file %s", data_filename);
    }
-   TextReader<data_size_t> predict_data_reader(data_filename, has_header);
+   TextReader<data_size_t> predict_data_reader(data_filename, header);
    std::unordered_map<int, int> feature_names_map_;
    bool need_adjust = false;
-   if (has_header) {
+   if (header) {
      std::string first_line = predict_data_reader.first_line();
-     std::vector<std::string> header = Common::Split(first_line.c_str(), "\t,");
-     header.erase(header.begin() + boosting_->LabelIdx());
-     for (int i = 0; i < static_cast<int>(header.size()); ++i) {
+     std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
+     header_words.erase(header_words.begin() + boosting_->LabelIdx());
+     for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
        for (int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) {
-         if (header[i] == boosting_->FeatureNames()[j]) {
+         if (header_words[i] == boosting_->FeatureNames()[j]) {
            feature_names_map_[i] = j;
            break;
          }
......
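The last hunk renames the local `header` vector to `header_words` so it no longer shadows the new `header` parameter, then maps each input column to the model feature with the same name. A simplified standalone sketch of that mapping loop — `MapHeaderToFeatures` is a hypothetical free function for illustration, not the library's API:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Map each column of the input header to the index of the model feature
// with the same name; columns with no match are simply absent from the map.
std::unordered_map<int, int> MapHeaderToFeatures(
    const std::vector<std::string>& header_words,
    const std::vector<std::string>& model_feature_names) {
  std::unordered_map<int, int> feature_names_map;
  for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
    for (int j = 0; j < static_cast<int>(model_feature_names.size()); ++j) {
      if (header_words[i] == model_feature_names[j]) {
        feature_names_map[i] = j;  // column i feeds model feature j
        break;
      }
    }
  }
  return feature_names_map;
}
```

When any column index differs from its mapped feature index, the caller sets a flag like `need_adjust` and reorders values before prediction.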
@@ -32,17 +32,17 @@ public:
   * \param training_metrics Training metrics
   * \param output_model_filename Filename of output model
   */
-  void Init(const BoostingConfig* config, const Dataset* train_data,
+  void Init(const Config* config, const Dataset* train_data,
            const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, objective_function, training_metrics);
-   random_for_drop_ = Random(gbdt_config_->drop_seed);
+   random_for_drop_ = Random(config_->drop_seed);
    sum_weight_ = 0.0f;
  }

-  void ResetConfig(const BoostingConfig* config) override {
+  void ResetConfig(const Config* config) override {
    GBDT::ResetConfig(config);
-   random_for_drop_ = Random(gbdt_config_->drop_seed);
+   random_for_drop_ = Random(config_->drop_seed);
    sum_weight_ = 0.0f;
  }
@@ -57,7 +57,7 @@ public:
    }
    // normalize
    Normalize();
-   if (!gbdt_config_->uniform_drop) {
+   if (!config_->uniform_drop) {
      tree_weight_.push_back(shrinkage_rate_);
      sum_weight_ += shrinkage_rate_;
    }
@@ -85,31 +85,31 @@ private:
   */
  void DroppingTrees() {
    drop_index_.clear();
-   bool is_skip = random_for_drop_.NextFloat() < gbdt_config_->skip_drop;
+   bool is_skip = random_for_drop_.NextFloat() < config_->skip_drop;
    // select dropping tree indices based on drop_rate and tree weights
    if (!is_skip) {
-     double drop_rate = gbdt_config_->drop_rate;
-     if (!gbdt_config_->uniform_drop) {
+     double drop_rate = config_->drop_rate;
+     if (!config_->uniform_drop) {
        double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
-       if (gbdt_config_->max_drop > 0) {
-         drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
+       if (config_->max_drop > 0) {
+         drop_rate = std::min(drop_rate, config_->max_drop * inv_average_weight / sum_weight_);
        }
        for (int i = 0; i < iter_; ++i) {
          if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
            drop_index_.push_back(num_init_iteration_ + i);
-           if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
+           if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
              break;
            }
          }
        }
      } else {
-       if (gbdt_config_->max_drop > 0) {
-         drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
+       if (config_->max_drop > 0) {
+         drop_rate = std::min(drop_rate, config_->max_drop / static_cast<double>(iter_));
        }
        for (int i = 0; i < iter_; ++i) {
          if (random_for_drop_.NextFloat() < drop_rate) {
            drop_index_.push_back(num_init_iteration_ + i);
-           if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
+           if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
              break;
            }
          }
@@ -124,13 +124,13 @@ private:
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
    }
-   if (!gbdt_config_->xgboost_dart_mode) {
-     shrinkage_rate_ = gbdt_config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
+   if (!config_->xgboost_dart_mode) {
+     shrinkage_rate_ = config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
    } else {
      if (drop_index_.empty()) {
-       shrinkage_rate_ = gbdt_config_->learning_rate;
+       shrinkage_rate_ = config_->learning_rate;
      } else {
-       shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
+       shrinkage_rate_ = config_->learning_rate / (config_->learning_rate + static_cast<double>(drop_index_.size()));
      }
    }
  }
@@ -146,7 +146,7 @@ private:
   */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
-   if (!gbdt_config_->xgboost_dart_mode) {
+   if (!config_->xgboost_dart_mode) {
      for (auto i : drop_index_) {
        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
@@ -159,7 +159,7 @@ private:
        models_[curr_tree]->Shrinkage(-k);
        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
-     if (!gbdt_config_->uniform_drop) {
+     if (!config_->uniform_drop) {
        sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
        tree_weight_[i] *= (k / (k + 1.0f));
      }
@@ -174,12 +174,12 @@ private:
        score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
      }
      // update training score
-     models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
+     models_[curr_tree]->Shrinkage(-k / config_->learning_rate);
      train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
    }
-   if (!gbdt_config_->uniform_drop) {
-     sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));
-     tree_weight_[i] *= (k / (k + gbdt_config_->learning_rate));
+   if (!config_->uniform_drop) {
+     sum_weight_ -= tree_weight_[i] * (1.0f / (k + config_->learning_rate));
+     tree_weight_[i] *= (k / (k + config_->learning_rate));
    }
  }
}
......
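The DART hunks above only swap `gbdt_config_` for `config_`; the shrinkage arithmetic is unchanged. For reference, the two modes it implements, with `k` the number of dropped trees and `lr` the learning rate (a standalone restatement of the formulas in the diff, not the library's API):

```cpp
#include <cassert>
#include <cmath>

// Standard DART mode: the new tree is scaled by lr / (1 + k).
double DartShrinkage(double lr, double k) {
  return lr / (1.0 + k);
}

// xgboost_dart_mode: lr / (lr + k), falling back to lr when no tree dropped.
double XgbDartShrinkage(double lr, double k) {
  return k == 0.0 ? lr : lr / (lr + k);
}
```

`Normalize()` then rescales the dropped trees by the complementary factor (k / (k + 1) or k / (k + lr)) so the total contribution of old and new trees stays balanced.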
@@ -61,7 +61,7 @@ GBDT::~GBDT() {
 #endif
 }

-void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+void GBDT::Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
               const std::vector<const Metric*>& training_metrics) {
  CHECK(train_data != nullptr);
  CHECK(train_data->num_features() > 0);
@@ -70,9 +70,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
  num_iteration_for_pred_ = 0;
  max_feature_idx_ = 0;
  num_class_ = config->num_class;
-  gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
-  early_stopping_round_ = gbdt_config_->early_stopping_round;
-  shrinkage_rate_ = gbdt_config_->learning_rate;
+  config_ = std::unique_ptr<Config>(new Config(*config));
+  early_stopping_round_ = config_->early_stopping_round;
+  shrinkage_rate_ = config_->learning_rate;
  std::string forced_splits_path = config->forcedsplits_filename;
  // load forced_splits file
@@ -93,7 +93,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
    is_constant_hessian_ = false;
  }
-  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));
+  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get()));
  // init tree learner
  tree_learner_->Init(train_data_, is_constant_hessian_);
@@ -123,7 +123,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
  feature_infos_ = train_data_->feature_infos();
  // if need bagging, create buffer
-  ResetBaggingConfig(gbdt_config_.get(), true);
+  ResetBaggingConfig(config_.get(), true);
  // reset config for tree learner
  class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
@@ -214,7 +214,7 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
  if (cnt <= 0) {
    return 0;
  }
-  data_size_t bag_data_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
+  data_size_t bag_data_cnt = static_cast<data_size_t>(config_->bagging_fraction * cnt);
  data_size_t cur_left_cnt = 0;
  data_size_t cur_right_cnt = 0;
  auto right_buffer = buffer + bag_data_cnt;
@@ -233,7 +233,7 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
 void GBDT::Bagging(int iter) {
  // if need bagging
-  if ((bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0)
+  if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0)
      || need_re_bagging_) {
    need_re_bagging_ = false;
    const data_size_t min_inner_size = 1000;
@@ -249,7 +249,7 @@ void GBDT::Bagging(int iter) {
    if (cur_start > num_data_) { continue; }
    data_size_t cur_cnt = inner_size;
    if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
-   Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
+   Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
    data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
    offsets_buf_[i] = cur_start;
    left_cnts_buf_[i] = cur_left_count;
@@ -318,7 +318,7 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj) {
 void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
  bool is_finished = false;
  auto start_time = std::chrono::steady_clock::now();
-  for (int iter = 0; iter < gbdt_config_->num_iterations && !is_finished; ++iter) {
+  for (int iter = 0; iter < config_->num_iterations && !is_finished; ++iter) {
    is_finished = TrainOneIter(nullptr, nullptr);
    if (!is_finished) {
      is_finished = EvalAndCheckEarlyStopping();
@@ -364,7 +364,7 @@ double GBDT::BoostFromAverage() {
  if (models_.empty() && !train_score_updater_->has_init_score()
      && num_class_ <= 1
      && objective_function_ != nullptr) {
-   if (gbdt_config_->boost_from_average) {
+   if (config_->boost_from_average) {
      double init_score = ObtainAutomaticInitialScore(objective_function_);
      if (std::fabs(init_score) > kEpsilon) {
        train_score_updater_->AddScore(init_score, 0);
@@ -580,7 +580,7 @@ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* scor
 }

 std::string GBDT::OutputMetric(int iter) {
-  bool need_output = (iter % gbdt_config_->output_freq) == 0;
+  bool need_output = (iter % config_->metric_freq) == 0;
  std::string ret = "";
  std::stringstream msg_buf;
  std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
@@ -777,24 +777,24 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
    feature_infos_ = train_data_->feature_infos();
    tree_learner_->ResetTrainingData(train_data);
-   ResetBaggingConfig(gbdt_config_.get(), true);
+   ResetBaggingConfig(config_.get(), true);
  }
 }

-void GBDT::ResetConfig(const BoostingConfig* config) {
-  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
+void GBDT::ResetConfig(const Config* config) {
+  auto new_config = std::unique_ptr<Config>(new Config(*config));
  early_stopping_round_ = new_config->early_stopping_round;
  shrinkage_rate_ = new_config->learning_rate;
  if (tree_learner_ != nullptr) {
-   tree_learner_->ResetConfig(&new_config->tree_config);
+   tree_learner_->ResetConfig(new_config.get());
  }
  if (train_data_ != nullptr) {
    ResetBaggingConfig(new_config.get(), false);
  }
-  gbdt_config_.reset(new_config.release());
+  config_.reset(new_config.release());
 }

-void GBDT::ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset) {
+void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
  // if need bagging, create buffer
  if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
    bag_data_cnt_ =
......
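The bagging hunks keep the same arithmetic under the renamed config: each chunk of `cnt` rows keeps `bagging_fraction * cnt` sampled indices, and a re-bag happens when a bag is active and the iteration number hits `bagging_freq`. A standalone restatement of those two expressions (helper names here are illustrative):

```cpp
#include <cassert>

// Rows kept per bag: floor(bagging_fraction * cnt), as in BaggingHelper.
int BagDataCount(double bagging_fraction, int cnt) {
  return static_cast<int>(bagging_fraction * cnt);
}

// Bagging re-runs when a bag is in effect and iter is a multiple of
// bagging_freq (the need_re_bagging_ override is omitted here).
bool NeedBagging(int bag_data_cnt, int num_data, int iter, int bagging_freq) {
  return bag_data_cnt < num_data && iter % bagging_freq == 0;
}
```

Seeding each thread's `Random` with `bagging_seed + iter * num_threads_ + i` makes the per-iteration sample reproducible yet different across threads and iterations.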
@@ -43,7 +43,7 @@ public:
   * \param objective_function Training objective function
   * \param training_metrics Training metrics
   */
-  void Init(const BoostingConfig* gbdt_config, const Dataset* train_data,
+  void Init(const Config* gbdt_config, const Dataset* train_data,
            const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override;
@@ -83,7 +83,7 @@ public:
  /*!
   * \brief Reset Boosting Config
   * \param gbdt_config Config for boosting
   */
-  void ResetConfig(const BoostingConfig* gbdt_config) override;
+  void ResetConfig(const Config* gbdt_config) override;

  /*!
   * \brief Adding a validation dataset
@@ -335,7 +335,7 @@ protected:
  /*!
   * \brief reset config for bagging
   */
-  void ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset);
+  void ResetBaggingConfig(const Config* config, bool is_change_dataset);

  /*!
   * \brief Implement bagging logic
@@ -384,7 +384,7 @@ protected:
  /*! \brief Pointer to training data */
  const Dataset* train_data_;
  /*! \brief Config of gbdt */
-  std::unique_ptr<BoostingConfig> gbdt_config_;
+  std::unique_ptr<Config> config_;
  /*! \brief Tree learner, will use this class to learn trees */
  std::unique_ptr<TreeLearner> tree_learner_;
  /*! \brief Objective function */
......
@@ -300,6 +300,10 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
  for (size_t i = 0; i < pairs.size(); ++i) {
    ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
  }
+  if (config_ != nullptr) {
+    ss << "parameters:" << '\n';
+    ss << config_->ToString() << "\n";
+  }
  return ss.str();
 }
......
@@ -39,7 +39,7 @@ public:
 #endif
  }

-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, objective_function, training_metrics);
    ResetGoss();
@@ -51,15 +51,15 @@ public:
    ResetGoss();
  }

-  void ResetConfig(const BoostingConfig* config) override {
+  void ResetConfig(const Config* config) override {
    GBDT::ResetConfig(config);
    ResetGoss();
  }

  void ResetGoss() {
-   CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
-   CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
-   if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
+   CHECK(config_->top_rate + config_->other_rate <= 1.0f);
+   CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
+   if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) {
      Log::Fatal("Cannot use bagging in GOSS");
    }
    Log::Info("Using GOSS");
@@ -74,8 +74,8 @@ public:
    right_write_pos_buf_.resize(num_threads_);
    is_use_subset_ = false;
-   if (gbdt_config_->top_rate + gbdt_config_->other_rate <= 0.5) {
-     auto bag_data_cnt = static_cast<data_size_t>((gbdt_config_->top_rate + gbdt_config_->other_rate) * num_data_);
+   if (config_->top_rate + config_->other_rate <= 0.5) {
+     auto bag_data_cnt = static_cast<data_size_t>((config_->top_rate + config_->other_rate) * num_data_);
      bag_data_cnt = std::max(1, bag_data_cnt);
      tmp_subset_.reset(new Dataset(bag_data_cnt));
      tmp_subset_->CopyFeatureMapperFrom(train_data_);
@@ -93,8 +93,8 @@ public:
        tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
      }
    }
-   data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate);
-   data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate);
+   data_size_t top_k = static_cast<data_size_t>(cnt * config_->top_rate);
+   data_size_t other_k = static_cast<data_size_t>(cnt * config_->other_rate);
    top_k = std::max(1, top_k);
    ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k - 1);
    score_t threshold = tmp_gradients[top_k - 1];
@@ -135,7 +135,7 @@ public:
  void Bagging(int iter) override {
    bag_data_cnt_ = num_data_;
    // not subsample for first iterations
-   if (iter < static_cast<int>(1.0f / gbdt_config_->learning_rate)) { return; }
+   if (iter < static_cast<int>(1.0f / config_->learning_rate)) { return; }
    const data_size_t min_inner_size = 100;
    data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
@@ -150,7 +150,7 @@ public:
    if (cur_start > num_data_) { continue; }
    data_size_t cur_cnt = inner_size;
    if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
-   Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
+   Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
    data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt,
                                               tmp_indices_.data() + cur_start, tmp_indice_right_.data() + cur_start);
    offsets_buf_[i] = cur_start;
......
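The two `Bagging` hunks above encode two small rules: bagging is skipped while `iter < 1 / learning_rate`, and each worker thread derives its own deterministic seed from `bagging_seed`, the iteration, and the thread index. A hedged sketch of both rules (the helper names `SkipBagging` and `ThreadSeed` are hypothetical, not LightGBM API):

```cpp
// Hypothetical restatement of the warm-up and seeding rules from the diff.
// Assumes the same meaning for learning_rate, bagging_seed, num_threads.

// No subsampling during the first 1/learning_rate iterations.
bool SkipBagging(int iter, float learning_rate) {
  return iter < static_cast<int>(1.0f / learning_rate);
}

// Per-thread seed: reproducible across runs, distinct across threads
// and iterations (mirrors config_->bagging_seed + iter * num_threads_ + i).
int ThreadSeed(int bagging_seed, int iter, int num_threads, int thread_id) {
  return bagging_seed + iter * num_threads + thread_id;
}
```

The additive seed formula guarantees that no two (iteration, thread) pairs share a seed as long as `thread_id < num_threads`, which is what makes parallel bagging deterministic here.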
@@ -24,10 +24,10 @@ public:
  ~RF() {}
-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
    CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
-    CHECK(config->tree_config.feature_fraction < 1.0f && config->tree_config.feature_fraction > 0.0f);
+    CHECK(config->feature_fraction < 1.0f && config->feature_fraction > 0.0f);
    GBDT::Init(config, train_data, objective_function, training_metrics);
    if (num_init_iteration_ > 0) {
@@ -50,9 +50,9 @@ public:
    }
  }
-  void ResetConfig(const BoostingConfig* config) override {
+  void ResetConfig(const Config* config) override {
    CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
-    CHECK(config->tree_config.feature_fraction < 1.0f && config->tree_config.feature_fraction > 0.0f);
+    CHECK(config->feature_fraction < 1.0f && config->feature_fraction > 0.0f);
    GBDT::ResetConfig(config);
    // not shrinkage rate for the RF
    shrinkage_rate_ = 1.0f;
...
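The `RF` hunks above show the main effect of this refactor: the nested `config->tree_config.feature_fraction` access becomes a flat `config->feature_fraction`, so the random-forest preconditions read off a single object. A minimal sketch of that validation under a stand-in `Config` struct (the struct and `ValidRFConfig` are illustrative, not LightGBM's actual types; the field defaults are assumptions):

```cpp
// Hypothetical flattened Config stand-in; real LightGBM's Config has many
// more fields, but these three are the ones the RF CHECKs inspect.
struct Config {
  int bagging_freq = 1;
  float bagging_fraction = 0.9f;
  float feature_fraction = 0.9f;
};

// Mirrors the two CHECK() lines in RF::Init / RF::ResetConfig:
// RF requires bagging enabled and both fractions strictly inside (0, 1).
bool ValidRFConfig(const Config& c) {
  return c.bagging_freq > 0 &&
         c.bagging_fraction > 0.0f && c.bagging_fraction < 1.0f &&
         c.feature_fraction > 0.0f && c.feature_fraction < 1.0f;
}
```

Note the bounds are strict: `bagging_fraction = 1.0` would mean no row subsampling at all, which defeats random-forest mode, so the CHECK rejects it.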