Unverified Commit dc699574 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generate

* auto generate config codes

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* tmp for failed test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some miss alias

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
@@ -32,7 +32,7 @@ Core Parameters

  - **Note**: Only can be used in CLI version

-- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``
+- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``, alias=\ ``task_type``

  - ``train``, alias=\ ``training``, for training

@@ -47,7 +47,7 @@ Core Parameters

 - ``application``, default=\ ``regression``, type=enum,
   options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gammma``, ``tweedie``,
   ``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``,
-  alias=\ ``objective``, ``app``
+  alias=\ ``app``, ``objective``, ``objective_type``

  - regression application

@@ -107,11 +107,11 @@ Core Parameters

  - ``goss``, Gradient-based One-Side Sampling

-- ``data``, default=\ ``""``, type=string, alias=\ ``train``, ``train_data``
+- ``data``, default=\ ``""``, type=string, alias=\ ``train``, ``train_data``, ``data_filename``

  - training data, LightGBM will train from this data

-- ``valid``, default=\ ``""``, type=multi-string, alias=\ ``test``, ``valid_data``, ``test_data``
+- ``valid``, default=\ ``""``, type=multi-string, alias=\ ``test``, ``valid_data``, ``test_data``, ``valid_filenames``

  - validation/test data, LightGBM will output metrics for these data

@@ -137,7 +137,7 @@ Core Parameters

  - number of leaves in one tree

-- ``tree_learner``, default=\ ``serial``, type=enum, options=\ ``serial``, ``feature``, ``data``, ``voting``, alias=\ ``tree``
+- ``tree_learner``, default=\ ``serial``, type=enum, options=\ ``serial``, ``feature``, ``data``, ``voting``, alias=\ ``tree``, ``tree_learner_type``

  - ``serial``, single machine tree learner

@@ -149,7 +149,7 @@ Core Parameters

  - refer to `Parallel Learning Guide <./Parallel-Learning-Guide.rst>`__ to get more details

-- ``num_threads``, default=\ ``OpenMP_default``, type=int, alias=\ ``num_thread``, ``nthread``
+- ``num_threads``, default=\ ``OpenMP_default``, type=int, alias=\ ``num_thread``, ``nthread``, ``nthreads``

  - number of threads for LightGBM

@@ -204,7 +204,7 @@ Learning Control Parameters

  - random seed for ``feature_fraction``

-- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``
+- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``, ``bagging``

  - like ``feature_fraction``, but this will randomly select part of data without resampling

@@ -312,7 +312,7 @@ Learning Control Parameters

  - set this to larger value for more accurate result, but it will slow down the training speed

-- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``
+- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``, ``monotone_constraints``

  - used for constraints of monotonic features

@@ -443,7 +443,7 @@ IO Parameters

  - **Note**: the negative values will be treated as **missing values**

-- ``predict_raw_score``, default=\ ``false``, type=bool, alias=\ ``raw_score``, ``is_predict_raw_score``
+- ``predict_raw_score``, default=\ ``false``, type=bool, alias=\ ``raw_score``, ``is_predict_raw_score``, ``predict_rawscore``

  - only used in ``prediction`` task

@@ -501,17 +501,17 @@ IO Parameters

  - set to ``false`` to use ``na`` to represent missing values

-- ``init_score_file``, default=\ ``""``, type=string
+- ``init_score_file``, default=\ ``""``, type=string, alias=\ ``init_score_filename``, ``initscore_filename``, ``init_score``

  - path to training initial score file, ``""`` will use ``train_data_file`` + ``.init`` (if exists)

-- ``valid_init_score_file``, default=\ ``""``, type=multi-string
+- ``valid_init_score_file``, default=\ ``""``, type=multi-string, alias=\ ``valid_data_initscores``, ``valid_data_init_scores``, ``valid_init_score``

  - path to validation initial score file, ``""`` will use ``valid_data_file`` + ``.init`` (if exists)

  - separate by ``,`` for multi-validation data

-- ``forced_splits``, default=\ ``""``, type=string
+- ``forced_splits``, default=\ ``""``, type=string, alias=\ ``forced_splits_file``, ``forcedsplits_filename``, ``forced_splits_filename``

  - path to a ``.json`` file that specifies splits to force at the top of every decision tree before best-first learning commences

@@ -593,7 +593,7 @@ Objective Parameters

 Metric Parameters
 -----------------

-- ``metric``, default=\ ``''``, type=multi-enum
+- ``metric``, default=\ ``''``, type=multi-enum, alias=\ ``metric_types``

  - metric to be evaluated on the evaluation sets **in addition** to what is provided in the training arguments

@@ -650,7 +650,7 @@ Metric Parameters

  - frequency for metric output

-- ``train_metric``, default=\ ``false``, type=bool, alias=\ ``training_metric``, ``is_training_metric``
+- ``train_metric``, default=\ ``false``, type=bool, alias=\ ``training_metric``, ``is_training_metric``, ``is_provide_training_metric``

  - set this to ``true`` if you need to output metric result of training
......
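The aliases added in the documentation above are backed by a single canonical-name lookup table (`Config::alias_table` in the generated C++). A minimal stand-alone Python sketch of that resolution step — table contents abridged from this diff, helper names illustrative, not LightGBM's actual API:

```python
# Sketch of alias resolution: every alias maps to one canonical parameter
# name, and user-supplied keys are rewritten before the config is applied.
# Table entries are a small subset of the aliases added in this commit.
ALIAS_TABLE = {
    "task_type": "task",
    "app": "objective",
    "objective_type": "objective",
    "nthread": "num_threads",
    "nthreads": "num_threads",
    "bagging": "bagging_fraction",
    "tree_learner_type": "tree_learner",
}

def resolve_aliases(params):
    """Rewrite aliased keys to canonical names; the first value seen for a
    canonical name wins, later duplicates are dropped."""
    resolved = {}
    for key, value in params.items():
        canonical = ALIAS_TABLE.get(key, key)
        resolved.setdefault(canonical, value)
    return resolved

print(resolve_aliases({"nthreads": 4, "app": "binary"}))
```

The real implementation does this on the C++ side so that every language wrapper (CLI, Python, R) gets identical alias behavior from the one generated table.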
import os


def GetParameterInfos(config_hpp):
    is_inparameter = False
    parameter_group = None
    cur_key = None
    cur_info = {}
    keys = []
    member_infos = []
    with open(config_hpp) as config_hpp_file:
        for line in config_hpp_file:
            if "#pragma region Parameters" in line:
                is_inparameter = True
            elif "#pragma region" in line and "Parameters" in line:
                cur_key = line.split("region")[1].strip()
                keys.append(cur_key)
                member_infos.append([])
            elif '#pragma endregion' in line:
                if cur_key is not None:
                    cur_key = None
                elif is_inparameter:
                    is_inparameter = False
            elif cur_key is not None:
                line = line.strip()
                if line.startswith("//"):
                    tokens = line.split("//")[1].split("=")
                    key = tokens[0].strip()
                    val = '='.join(tokens[1:]).strip()
                    if key not in cur_info:
                        if key == "descl2":
                            cur_info["desc"] = []
                        else:
                            cur_info[key] = []
                    if key == "desc":
                        cur_info["desc"].append(["l1", val])
                    elif key == "descl2":
                        cur_info["desc"].append(["l2", val])
                    else:
                        cur_info[key].append(val)
                elif line:
                    has_eqsgn = False
                    tokens = line.split("=")
                    if len(tokens) == 2:
                        if "default" not in cur_info:
                            cur_info["default"] = [tokens[1][:-1].strip()]
                        has_eqsgn = True
                    tokens = line.split()
                    cur_info["inner_type"] = [tokens[0].strip()]
                    if "name" not in cur_info:
                        if has_eqsgn:
                            cur_info["name"] = [tokens[1].strip()]
                        else:
                            cur_info["name"] = [tokens[1][:-1].strip()]
                    member_infos[-1].append(cur_info)
                    cur_info = {}
    return (keys, member_infos)


def GetNames(infos):
    names = []
    for x in infos:
        for y in x:
            names.append(y["name"][0])
    return names


def GetAlias(infos):
    pairs = []
    for x in infos:
        for y in x:
            if "alias" in y:
                name = y["name"][0]
                alias = y["alias"][0].split(',')
                for name2 in alias:
                    pairs.append([name2.strip(), name])
    return pairs


def SetOneVarFromString(name, type, checks):
    ret = ""
    univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"}
    if "vector" not in type:
        ret += "  %s(params, \"%s\", &%s);\n" % (univar_mapper[type], name, name)
        if len(checks) > 0:
            for check in checks:
                ret += "  CHECK(%s %s);\n" % (name, check)
        ret += "\n"
    else:
        ret += "  if (GetString(params, \"%s\", &tmp_str)) {\n" % (name)
        type2 = type.split("<")[1][:-1]
        if type2 == "std::string":
            ret += "    %s = Common::Split(tmp_str.c_str(), ',');\n" % (name)
        else:
            ret += "    %s = Common::StringToArray<%s>(tmp_str, ',');\n" % (name, type2)
        ret += "  }\n\n"
    return ret


def GenParameterCode(config_hpp, config_out_cpp):
    keys, infos = GetParameterInfos(config_hpp)
    names = GetNames(infos)
    alias = GetAlias(infos)
    str_to_write = "/// This file is auto generated by LightGBM\\helper\\parameter_generator.py\n"
    str_to_write += "#include<LightGBM/config.h>\nnamespace LightGBM {\n"
    # alias table
    str_to_write += "std::unordered_map<std::string, std::string> Config::alias_table({\n"
    for pair in alias:
        str_to_write += "  {\"%s\", \"%s\"}, \n" % (pair[0], pair[1])
    str_to_write += "});\n\n"
    # names
    str_to_write += "std::unordered_set<std::string> Config::parameter_set({\n"
    for name in names:
        str_to_write += "  \"%s\", \n" % (name)
    str_to_write += "});\n\n"
    # from strings
    str_to_write += "void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {\n"
    str_to_write += "  std::string tmp_str = \"\";\n"
    for x in infos:
        for y in x:
            if "[doc-only]" in y:
                continue
            type = y["inner_type"][0]
            name = y["name"][0]
            checks = []
            if "check" in y:
                checks = y["check"]
            tmp = SetOneVarFromString(name, type, checks)
            str_to_write += tmp
    # tails
    str_to_write += "}\n\n"
    str_to_write += "std::string Config::SaveMembersToString() const {\n"
    str_to_write += "  std::stringstream str_buf;\n"
    for x in infos:
        for y in x:
            if "[doc-only]" in y:
                continue
            type = y["inner_type"][0]
            name = y["name"][0]
            if "vector" in type:
                if "int8" in type:
                    str_to_write += "  str_buf << \"[%s: \" << Common::Join(Common::ArrayCast<int8_t, int>(%s),\",\") << \"]\\n\";\n" % (name, name)
                else:
                    str_to_write += "  str_buf << \"[%s: \" << Common::Join(%s,\",\") << \"]\\n\";\n" % (name, name)
            else:
                str_to_write += "  str_buf << \"[%s: \" << %s << \"]\\n\";\n" % (name, name)
    # tails
    str_to_write += "  return str_buf.str();\n"
    str_to_write += "}\n\n"
    str_to_write += "}\n"
    with open(config_out_cpp, "w") as config_out_cpp_file:
        config_out_cpp_file.write(str_to_write)


if __name__ == "__main__":
    config_hpp = os.path.join(os.path.pardir, 'include', 'LightGBM', 'config.h')
    config_out_cpp = os.path.join(os.path.pardir, 'src', 'io', 'config_auto.cpp')
    GenParameterCode(config_hpp, config_out_cpp)
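The script above drives the code generation from structured `//` comments in `config.h`: each `key=value` comment line annotates the member declaration that follows it. A simplified, self-contained sketch of that comment convention (this is an illustration with a sample snippet inlined, not the real script):

```python
# Illustrates the "// key=value" annotation convention that
# parameter_generator.py parses: comment lines accumulate metadata for the
# C++ member declaration that ends the block. Simplified stand-alone demo.
SAMPLE = """\
// alias=num_thread,nthread,nthreads
// desc=number of threads for LightGBM
int num_threads = 0;
"""

def parse_member(text):
    """Collect key=value comment metadata, then read type/name from the
    member declaration line (whitespace-separated, semicolon-terminated)."""
    info = {}
    for line in text.strip().splitlines():
        line = line.strip()
        if line.startswith("//"):
            key, _, val = line[2:].strip().partition("=")
            info.setdefault(key.strip(), []).append(val.strip())
        else:
            tokens = line.split()
            info["inner_type"] = [tokens[0]]  # e.g. "int"
            info["name"] = [tokens[1]]        # e.g. "num_threads"
    return info

info = parse_member(SAMPLE)
print(info["name"], info["alias"])
```

From records like this the real script emits `Config::alias_table`, `Config::parameter_set`, and the `GetMembersFromString`/`SaveMembersToString` bodies in `config_auto.cpp`, while `// [doc-only]` entries are documented but skipped in the generated setters.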
@@ -56,7 +56,7 @@ private:
   void ConvertModel();
   /*! \brief All configs */
-  OverallConfig config_;
+  Config config_;
   /*! \brief Training data */
   std::unique_ptr<Dataset> train_data_;
   /*! \brief Validation data */
@@ -73,10 +73,10 @@ private:
 inline void Application::Run() {
-  if (config_.task_type == TaskType::kPredict || config_.task_type == TaskType::KRefitTree) {
+  if (config_.task == TaskType::kPredict || config_.task == TaskType::KRefitTree) {
     InitPredict();
     Predict();
-  } else if (config_.task_type == TaskType::kConvertModel) {
+  } else if (config_.task == TaskType::kConvertModel) {
     ConvertModel();
   } else {
     InitTrain();
......
@@ -32,7 +32,7 @@ public:
   * \param training_metrics Training metric
   */
   virtual void Init(
-    const BoostingConfig* config,
+    const Config* config,
     const Dataset* train_data,
     const ObjectiveFunction* objective_function,
     const std::vector<const Metric*>& training_metrics) = 0;
@@ -47,7 +47,7 @@ public:
   virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                                  const std::vector<const Metric*>& training_metrics) = 0;
-  virtual void ResetConfig(const BoostingConfig* config) = 0;
+  virtual void ResetConfig(const Config* config) = 0;
......
...@@ -16,27 +16,15 @@ ...@@ -16,27 +16,15 @@
namespace LightGBM { namespace LightGBM {
const std::string kDefaultTreeLearnerType = "serial"; /*! \brief Types of tasks */
const std::string kDefaultDevice = "cpu"; enum TaskType {
const std::string kDefaultBoostingType = "gbdt"; kTrain, kPredict, kConvertModel, KRefitTree
const std::string kDefaultObjectiveType = "regression"; };
const int kDefaultNumLeaves = 31; const int kDefaultNumLeaves = 31;
/*! struct Config {
* \brief The interface for Config
*/
struct ConfigBase {
public: public:
/*! \brief virtual destructor */ std::string ToString() const;
virtual ~ConfigBase() {}
/*!
* \brief Set current config object by params
* \param params Store the key and value for params
*/
virtual void Set(
const std::unordered_map<std::string, std::string>& params) = 0;
/*! /*!
* \brief Get string value by specific name of key * \brief Get string value by specific name of key
* \param params Store the key and value for params * \param params Store the key and value for params
...@@ -83,230 +71,627 @@ public: ...@@ -83,230 +71,627 @@ public:
static void KV2Map(std::unordered_map<std::string, std::string>& params, const char* kv); static void KV2Map(std::unordered_map<std::string, std::string>& params, const char* kv);
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters); static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);
};
/*! \brief Types of tasks */ #pragma region Parameters
enum TaskType { #pragma region Core Parameters
kTrain, kPredict, kConvertModel, KRefitTree
}; // [doc-only]
// alias=config_file
// desc=path of config file
// desc=**Note**: Only can be used in CLI version
std::string config = "";
// [doc-only]
// type=enum
// default=train
// options=train,predict,convert_model,refit
// alias=task_type
// desc=``train``, alias=\ ``training``, for training
// desc=``predict``, alias=\ ``prediction``, ``test``, for prediction
// desc=``convert_model``, for converting model file into if-else format, see more information in `Convert model parameters <#convert-model-parameters>`__
// desc=``refit``, alias = \ ``refit_tree``, refit existing models with new data
// desc=**Note**: Only can be used in CLI version
TaskType task = TaskType::kTrain;
// [doc-only]
// type=enum
// options=regression,regression_l1,huber,fair,poisson,quantile,mape,gammma,tweedie,binary,multiclass,multiclassova,xentropy,xentlambda,lambdarank
// alias=application,app,objective_type
// desc=regression application
// descl2=``regression_l2``, L2 loss, alias=\ ``regression``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
// descl2=``regression_l1``, L1 loss, alias=\ ``mean_absolute_error``, ``mae``
// descl2=``huber``, `Huber loss`_
// descl2=``fair``, `Fair loss`_
// descl2=``poisson``, `Poisson regression`_
// descl2=``quantile``, `Quantile regression`_
// descl2=``mape``, `MAPE loss`_, alias=\ ``mean_absolute_percentage_error``
// descl2=``gamma``, Gamma regression with log-link. It might be useful, e.g., for modeling insurance claims severity, or for any target that might be `gamma-distributed`_
// descl2=``tweedie``, Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any target that might be `tweedie-distributed`_
// desc=``binary``, binary `log loss`_ classification application
// desc=multi-class classification application
// descl2=``multiclass``, `softmax`_ objective function, alias=\ ``softmax``
// descl2=``multiclassova``, `One-vs-All`_ binary objective function, alias=\ ``multiclass_ova``, ``ova``, ``ovr``
// descl2=``num_class`` should be set as well
// desc=cross-entropy application
// descl2=``xentropy``, objective function for cross-entropy (with optional linear weights), alias=\ ``cross_entropy``
// descl2=``xentlambda``, alternative parameterization of cross-entropy, alias=\ ``cross_entropy_lambda``
// descl2=the label is anything in interval [0, 1]
// desc=``lambdarank``, `lambdarank`_ application
// descl2=the label should be ``int`` type in lambdarank tasks, and larger number represent the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)
// descl2=`label_gain <#objective-parameters>`__ can be used to set the gain(weight) of ``int`` label
// descl2=all values in ``label`` must be smaller than number of elements in ``label_gain``
std::string objective = "regression";
// [doc-only]
// type=enum
// alias=boosting_type,boost
// options=gbdt,rf,dart,goss
// desc=``gbdt``, traditional Gradient Boosting Decision Tree
// desc=``rf``, Random Forest
// desc=``dart``, `Dropouts meet Multiple Additive Regression Trees`_
// desc=``goss``, Gradient - based One - Side Sampling
std::string boosting = "gbdt";
// alias=train,train_data,data_filename
// desc=training data, LightGBM will train from this data
std::string data = "";
// alias=test,valid_data,test_data,valid_filenames
// desc=validation/test data, LightGBM will output metrics for these data
// desc=support multi validation data, separate by ``,``
std::vector<std::string> valid;
// alias=num_iteration,num_tree,num_trees,num_round,num_rounds,num_boost_round,n_estimators
// check=>=0
// desc=number of boosting iterations
// desc=**Note**: for Python/R package,**this parameter is ignored**, use num_boost_round (Python) or nrounds (R) input arguments of train and cv methods instead
// desc=**Note**: internally,LightGBM constructs num_class * num_iterations trees for multiclass problems
int num_iterations = 100;
/*! \brief Config for input and output files */ // alias=shrinkage_rate
struct IOConfig: public ConfigBase { // check=>0
public: // desc=shrinkage rate
// desc=in dart,it also affects on normalization weights of dropped trees
double learning_rate = 0.1;
// default=31
// alias = num_leaf
// check=>1
// desc=max number of leaves in one tree
int num_leaves = kDefaultNumLeaves;
// [doc-only]
// type=enum
// options=serial, feature, data, voting
// alias = tree, tree_learner_type
// desc=serial,single machine tree learner
// desc=feature,alias=feature_parallel,feature parallel tree learner
// desc=data,alias=data_parallel,data parallel tree learner
// desc=voting,alias=voting_parallel,voting parallel tree learner
// desc=refer to `Parallel Learning Guide <./Parallel-Learning-Guide.rst>`__ to get more details
std::string tree_learner = "serial";
// default=OpenMP_default
// alias = num_thread, nthread, nthreads
// desc = number of threads for LightGBM
// desc=for the best speed,set this to the number of **real CPU cores**,
// not the number of threads(most CPU using `hyper-threading`_ to generate 2 threads per CPU core)
// desc=do not set it too large if your dataset is small (do not use 64 threads for a dataset with 10,000 rows for instance)
// desc=be aware a task manager or any similar CPU monitoring tool might report cores not being fully utilized. **This is normal**
// desc=for parallel learning,should not use full CPU cores since this will cause poor performance for the network
int num_threads = 0;
// [doc-only]
// options=cpu,gpu
// desc = choose device for the tree learning, you can use GPU to achieve the faster learning
// desc=**Note**: it is recommended to use the smaller max_bin (e.g. 63) to get the better speed up
// desc=**Note**: for the faster speed,GPU use 32-bit float point to sum up by default,may affect the accuracy for some tasks.
// desc=You can set gpu_use_dp = true to enable 64 - bit float point, but it will slow down the training
// desc=**Note**: refer to `Installation Guide <./Installation-Guide.rst#build-gpu-version>`__ to build with GPU
std::string device_type = "cpu";
// [doc-only]
// alias=random_seed
// desc=Use this seed to generate seeds for others, e.g. data_random_seed.
// desc=Will be override if set other seeds as well
// default=none
int seed = 0;
#pragma endregion
#pragma region Learning Control Parameters
// desc=limit the max depth for tree model. This is used to deal with over-fitting when #data is small. Tree still grows by leaf-wise
// desc=< 0 means no limit
int max_depth = -1;
// alias = min_data_per_leaf, min_data, min_child_samples
// check=>=0
// desc=minimal number of data in one leaf. Can be used to deal with over-fitting
int min_data_in_leaf = 20;
// alias=min_sum_hessian_per_leaf,min_sum_hessian,min_hessian,min_child_weight
// check >=0
// desc=minimal sum hessian in one leaf. Like min_data_in_leaf,it can be used to deal with over-fitting
double min_sum_hessian_in_leaf = 1e-3;
// alias=sub_row,subsample,bagging
// check=>0
// check=<=1.0
// desc = like feature_fraction, but this will randomly select part of data without resampling
// desc=can be used to speed up training
// desc=can be used to deal with over-fitting
// desc=**Note**: To enable bagging,bagging_freq should be set to a non zero value as well
double bagging_fraction = 1.0;
// alias=subsample_freq
// desc=frequency for bagging,0 means disable bagging. k means will perform bagging at every k iteration
// desc=**Note**: to enable bagging,bagging_fraction should be set as well
int bagging_freq = 0;
// alias = bagging_fraction_seed
// desc = random seed for bagging
int bagging_seed = 3;
// alias = sub_feature, colsample_bytree
// check=>0
// check=<=1.0
// desc=LightGBM will randomly select part of features on each iteration if feature_fraction smaller than 1.0. For example, if set to 0.8, will select 80 % features before training each tree
// desc=can be used to speed up training
// desc=can be used to deal with over-fitting
double feature_fraction = 1.0;
// desc=random seed for feature_fraction
int feature_fraction_seed = 2;
// alias=early_stopping_rounds,early_stopping
// desc=will stop training if one metric of one validation data doesn't improve in last early_stopping_round rounds
// desc=enable when greater than 0
int early_stopping_round = 0;
// alias=max_tree_output,max_leaf_output
// desc=Used to limit the max output of tree leaves
// desc=when <= 0,there is not constraint
// desc=the final max output of leaves is learning_rate*max_delta_step
double max_delta_step = 0.0;
// alias=reg_alpha
// check=>=0
// desc=L1 regularization
double lambda_l1 = 0.0;
// alias = reg_lambda
// check=>=0
// desc = L2 regularization
double lambda_l2 = 0.0;
// alias=min_split_gain
// desc=the minimal gain to perform split
double min_gain_to_split = 0.0;
// check=>=0
// check=<=1.0
// desc=only used in dart
double drop_rate = 0.1;
// desc=only used in dart,max number of dropped trees on one iteration
// desc=<=0 means no limit
int max_drop = 50;
// check=>=0
// check=<=1.0
// desc=only used in dart,probability of skipping drop
double skip_drop = 0.5;
// desc=only used in dart,set this to true if want to use xgboost dart mode
bool xgboost_dart_mode = false;
// desc=only used in dart,set this to true if want to use uniform drop
bool uniform_drop = false;
// desc=only used in dart,random seed to choose dropping models
int drop_seed = 4;
// check=>=0
// check=<=1.0
// desc=only used in goss,the retain ratio of large gradient data
double top_rate = 0.2;
// check=>=0
// check=<=1.0
// desc=only used in goss,the retain ratio of small gradient data
double other_rate = 0.1;
// check=>0
// desc=min number of data per categorical group
int min_data_per_group = 100;
// check=>0
// desc=use for the categorical features
// desc=limit the max threshold points in categorical features
int max_cat_threshold = 32;
// check=>=0
// desc=L2 regularization in categorcial split
double cat_l2 = 10;
// check=>=0
// desc=used for the categorical features
// desc=this can reduce the effect of noises in categorical features,especially for categories with few data
double cat_smooth = 10;
// check=>0
// desc=when number of categories of one feature smaller than or equal to max_cat_to_onehot,one-vs-other split algorithm will be used
int max_cat_to_onehot = 4;
// alias = topk
// desc=used in `Voting parallel <./Parallel-Learning-Guide.rst#choose-appropriate-parallel-algorithm>`__
// desc=set this to larger value for more accurate result,but it will slow down the training speed
int top_k = 20;
// type = multi-int
// alias = mc,monotone_constraint
// default=none
// desc=used for constraints of monotonic features
// desc=1 means increasing,-1 means decreasing,0 means non-constraint
// desc=you need to specify all features in order. For example,mc=-1,0,1 means the decreasing for 1st feature,non-constraint for 2nd feature and increasing for the 3rd feature
std::vector<int8_t> monotone_constraints;
// alias=forced_splits_filename,forced_splits_file,forced_splits
// desc = path to a.json file that specifies splits to force at the top of every decision tree before best - first learning commences
// desc=.json file can be arbitrarily nested,and each split contains feature,threshold fields,as well as left and right fields representing subsplits.Categorical splits are forced in a one - hot fashion, with left representing the split containing the feature value and right representing other values
// desc=see `this file <https://github.com/Microsoft/LightGBM/tree/master/examples/binary_classification/forced_splits.json>`__ as an example
std::string forcedsplits_filename = "";
#pragma endregion
#pragma region IO Parameters
// check=>1
// desc=max number of bins that feature values will be bucketed in.
// desc=Small number of bins may reduce training accuracy but may increase general power(deal with over - fitting)
// desc=LightGBM will auto compress memory according max_bin.
// desc=For example, LightGBM will use uint8_t for feature value if max_bin = 255
int max_bin = 255; int max_bin = 255;
int num_class = 1;
// check=>0
// desc=min number of data inside one bin,use this to avoid one-data-one-bin (may over-fitting)
int min_data_in_bin = 3;
// desc=random seed for data partition in parallel learning (not include feature parallel)
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = "";
std::string initscore_filename = ""; // alias=model_output,model_out
std::vector<std::string> valid_data_filenames; // desc=file name of output model in training
std::vector<std::string> valid_data_initscores;
int snapshot_freq = -1;
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp"; // alias = model_input, model_in
// desc=file name of input model
// desc=for prediction task,this model will be used for prediction data
// desc=for train task,training will be continued from this model
std::string input_model = ""; std::string input_model = "";
int verbosity = 1; // alias=predict_result,prediction_result
int num_iteration_predict = -1; // desc=file name of prediction result in prediction task
bool is_pre_partition = false; std::string output_result = "LightGBM_predict_result.txt";
// alias = is_pre_partition
// desc=used for parallel learning (not include feature parallel)
// desc=true if training data are pre-partitioned,and different machines use different partitions
bool pre_partition = false;
// alias = is_sparse, enable_sparse
// desc = used to enable / disable sparse optimization.Set to false to disable sparse optimization
bool is_enable_sparse = true; bool is_enable_sparse = true;
/*! \brief The threshold of zero elements precentage for treating a feature as a sparse feature.
* Default is 0.8, where a feature is treated as a sparse feature when there are over 80% zeros. // check=>0
* When setting to 1.0, all features are processed as dense features. // check=<=1
// desc=the threshold of zero elements percentage for treating a feature as a sparse feature
double sparse_threshold = 0.8;
// alias=two_round_loading,use_two_round_loading
// desc=by default, LightGBM will map data file to memory and load features from memory. This will provide faster data loading speed, but it may run out of memory when the data file is very big
// desc=set this to true if data file is too big to fit in memory
bool two_round = false;
// alias=is_save_binary,is_save_binary_file
// desc=if true, LightGBM will save the dataset (including validation data) to a binary file. This speeds up the data loading for the next time
bool save_binary = false;
// alias=verbose
// desc=<0 = Fatal, =0 = Error (Warn), >0 = Info
int verbosity = 1;
// alias=has_header
// desc=set this to true if input data has header
bool header = false;
// alias=label
// desc=specify the label column
// desc=use number for index, e.g. label=0 means column\_0 is the label
// desc=add a prefix name: for column name, e.g. label=name:is_click
std::string label_column = "";
// alias=weight
// desc=specify the weight column
// desc=use number for index, e.g. weight=0 means column\_0 is the weight
// desc=add a prefix name: for column name, e.g. weight=name:weight
// desc=**Note**: index starts from 0. And it doesn't count the label column when the passing type is Index, e.g. when label is column\_0 and weight is column\_1, the correct parameter is weight=0
std::string weight_column = "";
// alias=query_column,group,query
// desc=specify the query/group id column
// desc=use number for index, e.g. query=0 means column\_0 is the query id
// desc=add a prefix name: for column name, e.g. query=name:query_id
// desc=**Note**: data should be grouped by query\_id. Index starts from 0. And it doesn't count the label column when the passing type is Index, e.g. when label is column\_0 and query\_id is column\_1, the correct parameter is query=0
std::string group_column = "";
// alias=ignore_feature,blacklist
// desc=specify some ignoring columns in training
// desc=use number for index, e.g. ignore_column=0,1,2 means column\_0, column\_1 and column\_2 will be ignored
// desc=add a prefix name: for column name, e.g. ignore_column=name:c1,c2,c3 means c1, c2 and c3 will be ignored
// desc=**Note**: works only in case of loading data directly from file
// desc=**Note**: index starts from 0. And it doesn't count the label column
std::string ignore_column = "";
// alias=categorical_column,cat_feature,cat_column
// desc=specify categorical features
// desc=use number for index, e.g. categorical_feature=0,1,2 means column\_0, column\_1 and column\_2 are categorical features
// desc=add a prefix name: for column name, e.g. categorical_feature=name:c1,c2,c3 means c1, c2 and c3 are categorical features
// desc=**Note**: only supports categorical with int type. Index starts from 0. And it doesn't count the label column
// desc=**Note**: negative values will be treated as **missing values**
std::string categorical_feature = "";
// alias=raw_score,is_predict_raw_score,predict_rawscore
// desc=only used in prediction task
// desc=set to true to predict only the raw scores
// desc=set to false to predict transformed scores
bool predict_raw_score = false;
// alias=leaf_index,is_predict_leaf_index
// desc=only used in prediction task
// desc=set to true to predict with leaf index of all trees
bool predict_leaf_index = false;
// alias=contrib,is_predict_contrib
// desc=only used in prediction task
// desc=set to true to estimate `SHAP values`_, which represent how each feature contributes to each prediction
// desc=produces number of features + 1 values, where the last value is the expected value of the model output over the training data
bool predict_contrib = false;
// desc=only used in prediction task
// desc=used to specify how many trained iterations will be used in prediction
// desc=<= 0 means no limit
int num_iteration_predict = -1;
// desc=if true, will use early-stopping to speed up the prediction. May affect the accuracy
bool pred_early_stop = false;
// desc=the frequency of checking early-stopping prediction
int pred_early_stop_freq = 10;
// desc=the threshold of margin in early-stopping prediction
double pred_early_stop_margin = 10.0;
// alias=subsample_for_bin
// check=>0
// desc=number of data points sampled to construct histogram bins
// desc=setting this larger will give a better training result, but will increase data loading time
// desc=set this to a larger value if data is very sparse
int bin_construct_sample_cnt = 200000;
// desc=set to false to disable the special handling of missing values
bool use_missing = true;
// desc=set to true to treat all zeros as missing values (including the unshown values in libsvm/sparse matrices)
// desc=set to false to use na to represent missing values
bool zero_as_missing = false;
// alias=init_score_filename,init_score_file,init_score
// desc=path to training initial score file, "" will use train_data_file + .init (if exists)
std::string initscore_filename = "";
// alias=valid_data_init_scores,valid_init_score_file,valid_init_score
// desc=path to validation initial score file, "" will use valid_data_file + .init (if exists)
// desc=separated by , for multi-validation data
std::vector<std::string> valid_data_initscores;
// desc=max cache size (unit: MB) for historical histogram. < 0 means no limit
double histogram_pool_size = -1.0;
// desc=set to true to enable auto loading from previous saved binary datasets
// desc=set to false to ignore the binary datasets
bool enable_load_from_binary_file = true;
// desc=set to false to disable Exclusive Feature Bundling (EFB), which is described in the LightGBM NIPS 2017 paper
// desc=disabling this may cause slow training speed for sparse datasets
bool enable_bundle = true;
// check=>=0
// check=<1
// desc=max conflict rate for bundles in EFB
// desc=set to zero to disallow conflicts and provide more accurate results
// desc=the speed may be faster if set to a larger value
double max_conflict_rate = 0.0;
// desc=frequency of saving model file snapshot
// desc=set to a positive number to enable this function
// desc=for example, the model file will be snapshotted at each iteration if set to 1
int snapshot_freq = -1;
// desc=only cpp is supported yet
// desc=if convert_model_language is set when task is set to train, the model will also be converted
std::string convert_model_language = "";
// desc=output file name of converted model
std::string convert_model = "gbdt_prediction.cpp";

#pragma endregion

#pragma region Objective Parameters

// alias=num_classes
// desc=need to specify this in multi-class classification
int num_class = 1;
// check=>0
// desc=parameter for the sigmoid function. Will be used in binary and multiclassova classification and in lambdarank
double sigmoid = 1.0;
// desc=parameter for `Huber loss`_ and `Quantile regression`_. Will be used in regression task
double alpha = 0.9;
// desc=parameter for `Fair loss`_. Will be used in regression task
double fair_c = 1.0;
// desc=parameter for `Poisson regression`_ to safeguard optimization
double poisson_max_delta_step = 0.7;
// desc=only used in regression task
// desc=adjusts initial score to the mean of labels for faster convergence
bool boost_from_average = true;
// alias=unbalanced_sets
// desc=used in binary classification
// desc=set this to true if training data are unbalanced
bool is_unbalance = false;
// check=>0
// desc=weight of positive class in binary classification task
double scale_pos_weight = 1.0;
// desc=only used in regression, usually works better for the large range of labels
// desc=will fit sqrt(label) instead, and prediction result will also be automatically converted to pow2(prediction)
bool reg_sqrt = false;
// desc=only used in tweedie regression
// desc=controls the variance of the tweedie distribution
// desc=set closer to 2 to shift towards a gamma distribution
// desc=set closer to 1 to shift towards a poisson distribution
double tweedie_variance_power = 1.5;
// default=0,1,3,7,15,31,63,...,2^30-1
// desc=used in lambdarank
// desc=relevant gain for labels. For example, the gain of label 2 is 3 when using the default label gains
// desc=separated by ,
std::vector<double> label_gain;
// check=>0
// desc=used in lambdarank
// desc=will optimize `NDCG`_ at this position
int max_position = 20;
#pragma endregion

#pragma region Metric Parameters

// [doc-only]
// alias=metric_types
// default=''
// type=multi-enum
// desc=metric to be evaluated on the evaluation sets **in addition** to what is provided in the training arguments
// descl2='' (empty string or not specified),metric corresponding to specified objective will be used (this is possible only for pre-defined objective functions, otherwise no evaluation metric will be added)
// descl2='None' (string,**not** a None value),no metric registered,alias=na
// descl2=l1,absolute loss,alias=mean_absolute_error,mae,regression_l1
// descl2=l2,square loss,alias=mean_squared_error,mse,regression_l2,regression
// descl2=l2_root,root square loss,alias=root_mean_squared_error,rmse
// descl2=quantile,`Quantile regression`_
// descl2=mape,`MAPE loss`_,alias=mean_absolute_percentage_error
// descl2=huber,`Huber loss`_
// descl2=fair,`Fair loss`_
// descl2=poisson,negative log-likelihood for `Poisson regression`_
// descl2=gamma,negative log-likelihood for Gamma regression
// descl2=gamma_deviance,residual deviance for Gamma regression
// descl2=tweedie,negative log-likelihood for Tweedie regression
// descl2=ndcg,`NDCG`_
// descl2=map,`MAP`_,alias=mean_average_precision
// descl2=auc,`AUC`_
// descl2=binary_logloss,`log loss`_,alias=binary
// descl2=binary_error,for one sample: 0 for correct classification,1 for error classification
// descl2=multi_logloss,log loss for multi-class classification,alias=multiclass,softmax,multiclassova,multiclass_ova,ova,ovr
// descl2=multi_error,error rate for multi-class classification
// descl2=xentropy,cross-entropy (with optional linear weights),alias=cross_entropy
// descl2=xentlambda,"intensity-weighted" cross-entropy,alias=cross_entropy_lambda
// descl2=kldiv,`Kullback-Leibler divergence`_,alias=kullback_leibler
// desc=supports multiple metrics, separated by ,
std::vector<std::string> metric;
// check=>0
// alias=output_freq
// desc=frequency for metric output
int metric_freq = 1;
// alias=training_metric,is_training_metric,train_metric
// desc=set this to true if you need to output metric results over the training dataset
bool is_provide_training_metric = false;
// default=1,2,3,4,5
// alias=ndcg_eval_at,ndcg_at
// desc=`NDCG`_ evaluation positions, separated by ,
std::vector<int> eval_at;

#pragma endregion

#pragma region Network Parameters
// alias=num_machine
// desc=used for parallel learning, the number of machines for parallel learning application
// desc=need to set this in both socket and mpi versions
int num_machines = 1;
// alias=local_port
// desc=TCP listen port for local machines
// desc=you should allow this port in firewall settings before training
int local_listen_port = 12400;
// desc=socket time-out in minutes
int time_out = 120;  // in minutes
// alias=mlist
// desc=file that lists machines for this parallel learning application
// desc=each line contains one IP and one port for one machine. The format is ip port, separated by space
std::string machine_list_filename = "";
// alias=workers,nodes
// desc=list of machines, format: ip1:port1,ip2:port2
std::string machines = "";
#pragma endregion

#pragma region GPU Parameters

// desc=OpenCL platform ID. Usually each GPU vendor exposes one OpenCL platform
// desc=default value is -1, means the system-wide default platform
int gpu_platform_id = -1;
// desc=OpenCL device ID in the specified platform. Each GPU in the selected platform has a unique device ID
// desc=default value is -1, means the default device in the selected platform
int gpu_device_id = -1;
// desc=set to true to use double precision math on GPU (default using single precision)
bool gpu_use_dp = false;

#pragma endregion

#pragma endregion

bool is_parallel = false;
bool is_parallel_find_bin = false;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static std::unordered_map<std::string, std::string> alias_table;
static std::unordered_set<std::string> parameter_set;

private:
void CheckParamConflict();
void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
std::string SaveMembersToString() const;
};
inline bool Config::GetString(
  const std::unordered_map<std::string, std::string>& params,
  const std::string& name, std::string* out) {
  if (params.count(name) > 0) {
...
  return false;
}
inline bool Config::GetInt(
  const std::unordered_map<std::string, std::string>& params,
  const std::string& name, int* out) {
  if (params.count(name) > 0) {
    if (!Common::AtoiAndCheck(params.at(name).c_str(), out)) {
      Log::Fatal("Parameter %s should be of type int, got \"%s\"",
                 name.c_str(), params.at(name).c_str());
    }
    return true;
  }
  return false;
}
inline bool Config::GetDouble(
  const std::unordered_map<std::string, std::string>& params,
  const std::string& name, double* out) {
  if (params.count(name) > 0) {
    if (!Common::AtofAndCheck(params.at(name).c_str(), out)) {
      Log::Fatal("Parameter %s should be of type double, got \"%s\"",
                 name.c_str(), params.at(name).c_str());
    }
    return true;
  }
  return false;
}
inline bool Config::GetBool(
  const std::unordered_map<std::string, std::string>& params,
  const std::string& name, bool* out) {
  if (params.count(name) > 0) {
...
    *out = true;
  } else {
    Log::Fatal("Parameter %s should be \"true\"/\"+\" or \"false\"/\"-\", got \"%s\"",
               name.c_str(), params.at(name).c_str());
  }
  return true;
}
...
struct ParameterAlias {
  static void KeyAliasTransform(std::unordered_map<std::string, std::string>* params) {
    std::unordered_map<std::string, std::string> tmp_map;
    for (const auto& pair : *params) {
      auto alias = Config::alias_table.find(pair.first);
      if (alias != Config::alias_table.end()) {  // found alias
        auto alias_set = tmp_map.find(alias->second);
        if (alias_set != tmp_map.end()) {  // alias already set
          // set priority by length & alphabetically to ensure reproducible behavior
          if (alias_set->second.size() < pair.first.size() ||
              (alias_set->second.size() == pair.first.size() && alias_set->second < pair.first)) {
            Log::Warning("%s is set with %s=%s, %s=%s will be ignored. Current value: %s=%s",
                         alias->second.c_str(), alias_set->second.c_str(), params->at(alias_set->second).c_str(),
                         pair.first.c_str(), pair.second.c_str(), alias->second.c_str(), params->at(alias_set->second).c_str());
          } else {
            Log::Warning("%s is set with %s=%s, will be overridden by %s=%s. Current value: %s=%s",
                         alias->second.c_str(), alias_set->second.c_str(), params->at(alias_set->second).c_str(),
                         pair.first.c_str(), pair.second.c_str(), alias->second.c_str(), pair.second.c_str());
            tmp_map[alias->second] = pair.first;
          }
        } else {  // alias not set
          tmp_map.emplace(alias->second, pair.first);
        }
      } else if (Config::parameter_set.find(pair.first) == Config::parameter_set.end()) {
        Log::Warning("Unknown parameter: %s", pair.first.c_str());
      }
    }
...
        params->emplace(pair.first, params->at(pair.second));
        params->erase(pair.second);
      } else {
        Log::Warning("%s is set=%s, %s=%s will be ignored. Current value: %s=%s",
                     pair.first.c_str(), alias->second.c_str(), pair.second.c_str(), params->at(pair.second).c_str(),
                     pair.first.c_str(), alias->second.c_str());
      }
    }
  }
......
...
 * \param label_idx index of label column
 * \return Object of parser
 */
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx);
};
/*! \brief The main class of data set,
...
    int** sample_non_zero_indices,
    const int* num_per_col,
    size_t total_sample_cnt,
    const Config& io_config);
/*! \brief Destructor */
LIGHTGBM_EXPORT ~Dataset();
......
...
class DatasetLoader {
 public:
  LIGHTGBM_EXPORT DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename);
  LIGHTGBM_EXPORT ~DatasetLoader();
...
  /*! \brief Check can load from binary file */
  std::string CheckCanLoadFromBin(const char* filename);
  const Config& config_;
  /*! \brief Random generator*/
  Random random_;
  /*! \brief prediction function for initial model */
......
...
 * \param type Specific type of metric
 * \param config Config for metric
 */
LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type, const Config& config);
};
...
 */
class DCGCalculator {
 public:
  static void DefaultEvalAt(std::vector<int>* eval_at);
  static void DefaultLabelGain(std::vector<double>* label_gain);
  /*!
   * \brief Initial logic
   * \param label_gain Gain for labels, default is 2^i - 1
   */
  static void Init(const std::vector<double>& label_gain);
  /*!
   * \brief Calculate the DCG score at position k
......
...
 * \brief Initialize
 * \param config Config of network setting
 */
static void Init(Config config);
/*!
 * \brief Initialize
 */
......
...
 * \param config Config for objective function
 */
LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& type,
                                                                  const Config& config);
/*!
 * \brief Load objective function from string object
......
@@ -170,7 +170,7 @@ public:
 std::string ToJSON() const;
 /*! \brief Serialize this object to if-else statement*/
-std::string ToIfElse(int index, bool is_predict_leaf_index) const;
+std::string ToIfElse(int index, bool predict_leaf_index) const;
 inline static bool IsZero(double fval) {
 if (fval > -kZeroThreshold && fval <= kZeroThreshold) {
@@ -307,9 +307,9 @@ private:
 std::string NodeToJSON(int index) const;
 /*! \brief Serialize one node to if-else statement*/
-std::string NodeToIfElse(int index, bool is_predict_leaf_index) const;
-std::string NodeToIfElseByMap(int index, bool is_predict_leaf_index) const;
+std::string NodeToIfElse(int index, bool predict_leaf_index) const;
+std::string NodeToIfElseByMap(int index, bool predict_leaf_index) const;
 double ExpectedValue() const;
......
@@ -36,9 +36,9 @@ public:
 /*!
 * \brief Reset tree configs
- * \param tree_config config of tree
+ * \param config config of tree
 */
-virtual void ResetConfig(const TreeConfig* tree_config) = 0;
+virtual void ResetConfig(const Config* config) = 0;
 /*!
 * \brief training tree model on dataset
@@ -85,11 +85,11 @@ public:
 * \brief Create object of tree learner
 * \param learner_type Type of tree learner
 * \param device_type Type of tree learner
- * \param tree_config config of tree
+ * \param config config of tree
 */
 static TreeLearner* CreateTreeLearner(const std::string& learner_type,
 const std::string& device_type,
-const TreeConfig* tree_config);
+const Config* config);
 };
 } // namespace LightGBM
......
@@ -33,7 +33,7 @@ Application::Application(int argc, char** argv) {
 if (config_.num_threads > 0) {
 omp_set_num_threads(config_.num_threads);
 }
-if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
+if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
 Log::Fatal("No training/prediction data, application quit");
 }
 omp_set_nested(0);
@@ -48,13 +48,13 @@ Application::~Application() {
 void Application::LoadParameters(int argc, char** argv) {
 std::unordered_map<std::string, std::string> params;
 for (int i = 1; i < argc; ++i) {
-ConfigBase::KV2Map(params, argv[i]);
+Config::KV2Map(params, argv[i]);
 }
 // check for alias
 ParameterAlias::KeyAliasTransform(&params);
 // read parameters from config file
-if (params.count("config_file") > 0) {
-TextReader<size_t> config_reader(params["config_file"].c_str(), false);
+if (params.count("config") > 0) {
+TextReader<size_t> config_reader(params["config"].c_str(), false);
 config_reader.ReadAllLines();
 if (!config_reader.Lines().empty()) {
 for (auto& line : config_reader.Lines()) {
@@ -66,11 +66,11 @@ void Application::LoadParameters(int argc, char** argv) {
 if (line.size() == 0) {
 continue;
 }
-ConfigBase::KV2Map(params, line.c_str());
+Config::KV2Map(params, line.c_str());
 }
 } else {
 Log::Warning("Config file %s doesn't exist, will ignore",
-params["config_file"].c_str());
+params["config"].c_str());
 }
 }
 // check for alias again
@@ -87,37 +87,37 @@ void Application::LoadData() {
 PredictFunction predict_fun = nullptr;
 PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
 // need to continue training
-if (boosting_->NumberOfTotalModel() > 0 && config_.task_type != TaskType::KRefitTree) {
+if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) {
 predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
 predict_fun = predictor->GetPredictFunction();
 }
 // sync up random seed for data partition
 if (config_.is_parallel_find_bin) {
-config_.io_config.data_random_seed = Network::GlobalSyncUpByMin(config_.io_config.data_random_seed);
+config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
 }
-DatasetLoader dataset_loader(config_.io_config, predict_fun,
-config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
+DatasetLoader dataset_loader(config_, predict_fun,
+config_.num_class, config_.data.c_str());
 // load Training data
 if (config_.is_parallel_find_bin) {
 // load data for parallel training
-train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
-config_.io_config.initscore_filename.c_str(),
+train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
+config_.initscore_filename.c_str(),
 Network::rank(), Network::num_machines()));
 } else {
 // load data for single machine
-train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
+train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
 0, 1));
 }
 // need save binary file
-if (config_.io_config.is_save_binary_file) {
+if (config_.save_binary) {
 train_data_->SaveBinaryFile(nullptr);
 }
 // create training metric
-if (config_.boosting_config.is_provide_training_metric) {
-for (auto metric_type : config_.metric_types) {
-auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
+if (config_.is_provide_training_metric) {
+for (auto metric_type : config_.metric) {
+auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
 if (metric == nullptr) { continue; }
 metric->Init(train_data_->metadata(), train_data_->num_data());
 train_metric_.push_back(std::move(metric));
@@ -126,28 +126,28 @@ void Application::LoadData() {
 train_metric_.shrink_to_fit();
-if (!config_.metric_types.empty()) {
+if (!config_.metric.empty()) {
 // only when have metrics then need to construct validation data
 // Add validation data, if it exists
-for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
+for (size_t i = 0; i < config_.valid.size(); ++i) {
 // add
 auto new_dataset = std::unique_ptr<Dataset>(
 dataset_loader.LoadFromFileAlignWithOtherDataset(
-config_.io_config.valid_data_filenames[i].c_str(),
-config_.io_config.valid_data_initscores[i].c_str(),
+config_.valid[i].c_str(),
+config_.valid_data_initscores[i].c_str(),
 train_data_.get())
 );
 valid_datas_.push_back(std::move(new_dataset));
 // need save binary file
-if (config_.io_config.is_save_binary_file) {
+if (config_.save_binary) {
 valid_datas_.back()->SaveBinaryFile(nullptr);
 }
 // add metric for validation data
 valid_metrics_.emplace_back();
-for (auto metric_type : config_.metric_types) {
-auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
+for (auto metric_type : config_.metric) {
+auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
 if (metric == nullptr) { continue; }
 metric->Init(valid_datas_.back()->metadata(),
 valid_datas_.back()->num_data());
@@ -167,30 +167,30 @@ void Application::LoadData() {
 void Application::InitTrain() {
 if (config_.is_parallel) {
 // need init network
-Network::Init(config_.network_config);
+Network::Init(config_);
 Log::Info("Finished initializing network");
-config_.boosting_config.tree_config.feature_fraction_seed =
-Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction_seed);
-config_.boosting_config.tree_config.feature_fraction =
-Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction);
-config_.boosting_config.drop_seed =
-Network::GlobalSyncUpByMin(config_.boosting_config.drop_seed);
+config_.feature_fraction_seed =
+Network::GlobalSyncUpByMin(config_.feature_fraction_seed);
+config_.feature_fraction =
+Network::GlobalSyncUpByMin(config_.feature_fraction);
+config_.drop_seed =
+Network::GlobalSyncUpByMin(config_.drop_seed);
 }
 // create boosting
 boosting_.reset(
-Boosting::CreateBoosting(config_.boosting_type,
-config_.io_config.input_model.c_str()));
+Boosting::CreateBoosting(config_.boosting,
+config_.input_model.c_str()));
 // create objective function
 objective_fun_.reset(
-ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
-config_.objective_config));
+ObjectiveFunction::CreateObjectiveFunction(config_.objective,
+config_));
 // load training data
 LoadData();
 // initialize the objective function
 objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
 // initialize the boosting
-boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
+boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
 Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
 // add validation data into boosting
 for (size_t i = 0; i < valid_datas_.size(); ++i) {
@@ -202,22 +202,22 @@ void Application::InitTrain() {
 void Application::Train() {
 Log::Info("Started training...");
-boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
-boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
+boosting_->Train(config_.snapshot_freq, config_.output_model);
+boosting_->SaveModelToFile(-1, config_.output_model.c_str());
 // convert model to if-else statement code
 if (config_.convert_model_language == std::string("cpp")) {
-boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
+boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
 }
 Log::Info("Finished training");
 }
 void Application::Predict() {
-if (config_.task_type == TaskType::KRefitTree) {
+if (config_.task == TaskType::KRefitTree) {
 // create predictor
 Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
-predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str(), config_.io_config.has_header);
-TextReader<int> result_reader(config_.io_config.output_result.c_str(), false);
+predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header);
+TextReader<int> result_reader(config_.output_result.c_str(), false);
 result_reader.ReadAllLines();
 std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
 #pragma omp parallel for schedule(static)
@@ -226,41 +226,41 @@ void Application::Predict() {
 // Free memory
 result_reader.Lines()[i].clear();
 }
-DatasetLoader dataset_loader(config_.io_config, nullptr,
-config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
-train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
+DatasetLoader dataset_loader(config_, nullptr,
+config_.num_class, config_.data.c_str());
+train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
 0, 1));
 train_metric_.clear();
-objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
-config_.objective_config));
+objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
+config_));
 objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
-boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
+boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
 Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
 boosting_->RefitTree(pred_leaf);
-boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
+boosting_->SaveModelToFile(-1, config_.output_model.c_str());
 Log::Info("Finished RefitTree");
 } else {
 // create predictor
-Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
-config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
-config_.io_config.pred_early_stop, config_.io_config.pred_early_stop_freq,
-config_.io_config.pred_early_stop_margin);
-predictor.Predict(config_.io_config.data_filename.c_str(),
-config_.io_config.output_result.c_str(), config_.io_config.has_header);
+Predictor predictor(boosting_.get(), config_.num_iteration_predict, config_.predict_raw_score,
+config_.predict_leaf_index, config_.predict_contrib,
+config_.pred_early_stop, config_.pred_early_stop_freq,
+config_.pred_early_stop_margin);
+predictor.Predict(config_.data.c_str(),
+config_.output_result.c_str(), config_.header);
 Log::Info("Finished prediction");
 }
 }
 void Application::InitPredict() {
 boosting_.reset(
-Boosting::CreateBoosting("gbdt", config_.io_config.input_model.c_str()));
+Boosting::CreateBoosting("gbdt", config_.input_model.c_str()));
 Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
 }
 void Application::ConvertModel() {
 boosting_.reset(
-Boosting::CreateBoosting(config_.boosting_type, config_.io_config.input_model.c_str()));
+Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str()));
-boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
+boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
 }
......
@@ -29,11 +29,11 @@ public:
 * \param boosting Input boosting model
 * \param num_iteration Number of boosting round
 * \param is_raw_score True if need to predict result with raw score
- * \param is_predict_leaf_index True to output leaf index instead of prediction score
- * \param is_predict_contrib True to output feature contributions instead of prediction score
+ * \param predict_leaf_index True to output leaf index instead of prediction score
+ * \param predict_contrib True to output feature contributions instead of prediction score
 */
 Predictor(Boosting* boosting, int num_iteration,
-bool is_raw_score, bool is_predict_leaf_index, bool is_predict_contrib,
+bool is_raw_score, bool predict_leaf_index, bool predict_contrib,
 bool early_stop, int early_stop_freq, double early_stop_margin) {
 early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
@@ -55,14 +55,14 @@ public:
 {
 num_threads_ = omp_get_num_threads();
 }
-boosting->InitPredict(num_iteration, is_predict_contrib);
+boosting->InitPredict(num_iteration, predict_contrib);
 boosting_ = boosting;
-num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
+num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, predict_leaf_index, predict_contrib);
 num_feature_ = boosting_->MaxFeatureIdx() + 1;
 predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
 const int kFeatureThreshold = 100000;
 const size_t KSparseThreshold = static_cast<size_t>(0.01 * num_feature_);
-if (is_predict_leaf_index) {
+if (predict_leaf_index) {
 predict_fun_ = [this, kFeatureThreshold, KSparseThreshold](const std::vector<std::pair<int, double>>& features, double* output) {
 int tid = omp_get_thread_num();
 if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
@@ -75,7 +75,7 @@ public:
 ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
 }
 };
-} else if (is_predict_contrib) {
+} else if (predict_contrib) {
 predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
 int tid = omp_get_thread_num();
 CopyToPredictBuffer(predict_buf_[tid].data(), features);
@@ -127,27 +127,27 @@ public:
 * \param data_filename Filename of data
 * \param result_filename Filename of output result
 */
-void Predict(const char* data_filename, const char* result_filename, bool has_header) {
+void Predict(const char* data_filename, const char* result_filename, bool header) {
 auto writer = VirtualFileWriter::Make(result_filename);
 if (!writer->Init()) {
 Log::Fatal("Prediction results file %s cannot be found", result_filename);
 }
-auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
+auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
 if (parser == nullptr) {
 Log::Fatal("Could not recognize the data format of data file %s", data_filename);
 }
-TextReader<data_size_t> predict_data_reader(data_filename, has_header);
+TextReader<data_size_t> predict_data_reader(data_filename, header);
 std::unordered_map<int, int> feature_names_map_;
 bool need_adjust = false;
-if (has_header) {
+if (header) {
 std::string first_line = predict_data_reader.first_line();
-std::vector<std::string> header = Common::Split(first_line.c_str(), "\t,");
-header.erase(header.begin() + boosting_->LabelIdx());
-for (int i = 0; i < static_cast<int>(header.size()); ++i) {
+std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
+header_words.erase(header_words.begin() + boosting_->LabelIdx());
+for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
 for (int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) {
-if (header[i] == boosting_->FeatureNames()[j]) {
+if (header_words[i] == boosting_->FeatureNames()[j]) {
 feature_names_map_[i] = j;
 break;
 }
......
@@ -32,17 +32,17 @@ public:
 * \param training_metrics Training metrics
 * \param output_model_filename Filename of output model
 */
-void Init(const BoostingConfig* config, const Dataset* train_data,
+void Init(const Config* config, const Dataset* train_data,
 const ObjectiveFunction* objective_function,
 const std::vector<const Metric*>& training_metrics) override {
 GBDT::Init(config, train_data, objective_function, training_metrics);
-random_for_drop_ = Random(gbdt_config_->drop_seed);
+random_for_drop_ = Random(config_->drop_seed);
 sum_weight_ = 0.0f;
 }
-void ResetConfig(const BoostingConfig* config) override {
+void ResetConfig(const Config* config) override {
 GBDT::ResetConfig(config);
-random_for_drop_ = Random(gbdt_config_->drop_seed);
+random_for_drop_ = Random(config_->drop_seed);
 sum_weight_ = 0.0f;
 }
@@ -57,7 +57,7 @@ public:
 }
 // normalize
 Normalize();
-if (!gbdt_config_->uniform_drop) {
+if (!config_->uniform_drop) {
 tree_weight_.push_back(shrinkage_rate_);
 sum_weight_ += shrinkage_rate_;
 }
@@ -85,31 +85,31 @@ private:
 */
 void DroppingTrees() {
 drop_index_.clear();
-bool is_skip = random_for_drop_.NextFloat() < gbdt_config_->skip_drop;
+bool is_skip = random_for_drop_.NextFloat() < config_->skip_drop;
 // select dropping tree indices based on drop_rate and tree weights
 if (!is_skip) {
-double drop_rate = gbdt_config_->drop_rate;
-if (!gbdt_config_->uniform_drop) {
+double drop_rate = config_->drop_rate;
+if (!config_->uniform_drop) {
 double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
-if (gbdt_config_->max_drop > 0) {
-drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
+if (config_->max_drop > 0) {
+drop_rate = std::min(drop_rate, config_->max_drop * inv_average_weight / sum_weight_);
 }
 for (int i = 0; i < iter_; ++i) {
 if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
 drop_index_.push_back(num_init_iteration_ + i);
-if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
+if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
 break;
 }
 }
 }
 } else {
-if (gbdt_config_->max_drop > 0) {
-drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
+if (config_->max_drop > 0) {
+drop_rate = std::min(drop_rate, config_->max_drop / static_cast<double>(iter_));
 }
 for (int i = 0; i < iter_; ++i) {
 if (random_for_drop_.NextFloat() < drop_rate) {
 drop_index_.push_back(num_init_iteration_ + i);
-if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
+if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
 break;
 }
 }
@@ -124,13 +124,13 @@ private:
 train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
 }
 }
-if (!gbdt_config_->xgboost_dart_mode) {
-shrinkage_rate_ = gbdt_config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
+if (!config_->xgboost_dart_mode) {
+shrinkage_rate_ = config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
 } else {
 if (drop_index_.empty()) {
-shrinkage_rate_ = gbdt_config_->learning_rate;
+shrinkage_rate_ = config_->learning_rate;
 } else {
-shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
+shrinkage_rate_ = config_->learning_rate / (config_->learning_rate + static_cast<double>(drop_index_.size()));
 }
 }
 }
@@ -146,7 +146,7 @@ private:
 */
 void Normalize() {
 double k = static_cast<double>(drop_index_.size());
-if (!gbdt_config_->xgboost_dart_mode) {
+if (!config_->xgboost_dart_mode) {
 for (auto i : drop_index_) {
 for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
 auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
@@ -159,7 +159,7 @@ private:
 models_[curr_tree]->Shrinkage(-k);
 train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
 }
-if (!gbdt_config_->uniform_drop) {
+if (!config_->uniform_drop) {
 sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
 tree_weight_[i] *= (k / (k + 1.0f));
 }
@@ -174,12 +174,12 @@ private:
 score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
 }
 // update training score
-models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
+models_[curr_tree]->Shrinkage(-k / config_->learning_rate);
 train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
 }
-if (!gbdt_config_->uniform_drop) {
-sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));;
-tree_weight_[i] *= (k / (k + gbdt_config_->learning_rate));
+if (!config_->uniform_drop) {
+sum_weight_ -= tree_weight_[i] * (1.0f / (k + config_->learning_rate));;
+tree_weight_[i] *= (k / (k + config_->learning_rate));
 }
 }
 }
......
@@ -61,7 +61,7 @@ GBDT::~GBDT() {
 #endif
 }
 
-void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+void GBDT::Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
                 const std::vector<const Metric*>& training_metrics) {
   CHECK(train_data != nullptr);
   CHECK(train_data->num_features() > 0);
@@ -70,9 +70,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
   num_iteration_for_pred_ = 0;
   max_feature_idx_ = 0;
   num_class_ = config->num_class;
-  gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
-  early_stopping_round_ = gbdt_config_->early_stopping_round;
-  shrinkage_rate_ = gbdt_config_->learning_rate;
+  config_ = std::unique_ptr<Config>(new Config(*config));
+  early_stopping_round_ = config_->early_stopping_round;
+  shrinkage_rate_ = config_->learning_rate;
   std::string forced_splits_path = config->forcedsplits_filename;
   // load forced_splits file
@@ -93,7 +93,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
     is_constant_hessian_ = false;
   }
-  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));
+  tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get()));
   // init tree learner
   tree_learner_->Init(train_data_, is_constant_hessian_);
@@ -123,7 +123,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
   feature_infos_ = train_data_->feature_infos();
   // if need bagging, create buffer
-  ResetBaggingConfig(gbdt_config_.get(), true);
+  ResetBaggingConfig(config_.get(), true);
   // reset config for tree learner
   class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
@@ -214,7 +214,7 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
   if (cnt <= 0) {
     return 0;
   }
-  data_size_t bag_data_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
+  data_size_t bag_data_cnt = static_cast<data_size_t>(config_->bagging_fraction * cnt);
   data_size_t cur_left_cnt = 0;
   data_size_t cur_right_cnt = 0;
   auto right_buffer = buffer + bag_data_cnt;
@@ -233,7 +233,7 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
 void GBDT::Bagging(int iter) {
   // if need bagging
-  if ((bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0)
+  if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0)
       || need_re_bagging_) {
     need_re_bagging_ = false;
     const data_size_t min_inner_size = 1000;
@@ -249,7 +249,7 @@ void GBDT::Bagging(int iter) {
       if (cur_start > num_data_) { continue; }
       data_size_t cur_cnt = inner_size;
       if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
-      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
+      Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
       data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
       offsets_buf_[i] = cur_start;
       left_cnts_buf_[i] = cur_left_count;
@@ -318,7 +318,7 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj) {
 void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
   bool is_finished = false;
   auto start_time = std::chrono::steady_clock::now();
-  for (int iter = 0; iter < gbdt_config_->num_iterations && !is_finished; ++iter) {
+  for (int iter = 0; iter < config_->num_iterations && !is_finished; ++iter) {
     is_finished = TrainOneIter(nullptr, nullptr);
     if (!is_finished) {
       is_finished = EvalAndCheckEarlyStopping();
@@ -364,7 +364,7 @@ double GBDT::BoostFromAverage() {
   if (models_.empty() && !train_score_updater_->has_init_score()
       && num_class_ <= 1
       && objective_function_ != nullptr) {
-    if (gbdt_config_->boost_from_average) {
+    if (config_->boost_from_average) {
       double init_score = ObtainAutomaticInitialScore(objective_function_);
       if (std::fabs(init_score) > kEpsilon) {
         train_score_updater_->AddScore(init_score, 0);
@@ -580,7 +580,7 @@ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* scor
 }
 
 std::string GBDT::OutputMetric(int iter) {
-  bool need_output = (iter % gbdt_config_->output_freq) == 0;
+  bool need_output = (iter % config_->metric_freq) == 0;
   std::string ret = "";
   std::stringstream msg_buf;
   std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
@@ -777,24 +777,24 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
     feature_infos_ = train_data_->feature_infos();
     tree_learner_->ResetTrainingData(train_data);
-    ResetBaggingConfig(gbdt_config_.get(), true);
+    ResetBaggingConfig(config_.get(), true);
   }
 }
 
-void GBDT::ResetConfig(const BoostingConfig* config) {
-  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
+void GBDT::ResetConfig(const Config* config) {
+  auto new_config = std::unique_ptr<Config>(new Config(*config));
   early_stopping_round_ = new_config->early_stopping_round;
   shrinkage_rate_ = new_config->learning_rate;
   if (tree_learner_ != nullptr) {
-    tree_learner_->ResetConfig(&new_config->tree_config);
+    tree_learner_->ResetConfig(new_config.get());
   }
   if (train_data_ != nullptr) {
     ResetBaggingConfig(new_config.get(), false);
   }
-  gbdt_config_.reset(new_config.release());
+  config_.reset(new_config.release());
 }
 
-void GBDT::ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset) {
+void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
   // if need bagging, create buffer
   if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
     bag_data_cnt_ =
...
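The `BaggingHelper`/`Bagging` hunks above only rename `gbdt_config_` to `config_`, but the partition idea is easy to lose in diff form: each thread's chunk of rows is split into roughly `bagging_fraction * cnt` in-bag rows and the rest out-of-bag, inside one shared index buffer. A minimal self-contained sketch of that idea, using `std::mt19937` in place of LightGBM's `Random` class and a simple front/back partition instead of the exact buffer layout (all names here are illustrative, not LightGBM's API):

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

using data_size_t = int32_t;

// Sketch of per-chunk bagging: keep each row with probability
// bagging_fraction, writing in-bag indices from the front of the buffer
// and out-of-bag indices from the back. Returns the in-bag count,
// analogous to BaggingHelper's return value.
data_size_t BaggingSketch(std::mt19937* rng, data_size_t start, data_size_t cnt,
                          double bagging_fraction, data_size_t* buffer) {
  if (cnt <= 0) return 0;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  data_size_t left = 0;     // in-bag rows fill the front
  data_size_t right = cnt;  // out-of-bag rows fill from the back
  for (data_size_t i = 0; i < cnt; ++i) {
    if (dist(*rng) < bagging_fraction) {
      buffer[left++] = start + i;
    } else {
      buffer[--right] = start + i;
    }
  }
  return left;
}
```

In the real code each thread handles one `[cur_start, cur_start + cur_cnt)` chunk with its own seeded `Random`, and the per-thread counts are later merged via `offsets_buf_`/`left_cnts_buf_`.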
@@ -43,7 +43,7 @@ public:
   * \param objective_function Training objective function
   * \param training_metrics Training metrics
   */
-  void Init(const BoostingConfig* gbdt_config, const Dataset* train_data,
+  void Init(const Config* gbdt_config, const Dataset* train_data,
             const ObjectiveFunction* objective_function,
             const std::vector<const Metric*>& training_metrics) override;
@@ -83,7 +83,7 @@ public:
   * \brief Reset Boosting Config
   * \param gbdt_config Config for boosting
   */
-  void ResetConfig(const BoostingConfig* gbdt_config) override;
+  void ResetConfig(const Config* gbdt_config) override;
 
   /*!
   * \brief Adding a validation dataset
@@ -335,7 +335,7 @@ protected:
   /*!
   * \brief reset config for bagging
   */
-  void ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset);
+  void ResetBaggingConfig(const Config* config, bool is_change_dataset);
   /*!
   * \brief Implement bagging logic
@@ -384,7 +384,7 @@ protected:
   /*! \brief Pointer to training data */
   const Dataset* train_data_;
   /*! \brief Config of gbdt */
-  std::unique_ptr<BoostingConfig> gbdt_config_;
+  std::unique_ptr<Config> config_;
   /*! \brief Tree learner, will use this class to learn trees */
   std::unique_ptr<TreeLearner> tree_learner_;
   /*! \brief Objective function */
...
@@ -300,6 +300,10 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
   for (size_t i = 0; i < pairs.size(); ++i) {
     ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
   }
+  if (config_ != nullptr) {
+    ss << "parameters:" << '\n';
+    ss << config_->ToString() << "\n";
+  }
   return ss.str();
 }
...
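The added `SaveModelToString` lines append a `parameters:` section produced by `Config::ToString()` to the saved model text, so a model file records the parameters it was trained with. A small sketch of that behavior, with a stand-in `ConfigSketch::ToString()` (LightGBM's actual implementation is auto-generated and prints many more fields):

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Stand-in config with a ToString() like the one the diff relies on.
struct ConfigSketch {
  int num_iterations = 100;
  double learning_rate = 0.1;
  std::string ToString() const {
    std::ostringstream os;
    os << "[num_iterations: " << num_iterations << "]\n"
       << "[learning_rate: " << learning_rate << "]";
    return os.str();
  }
};

// Mirrors the added hunk: append the parameters section only when a
// config object is present.
std::string AppendParameters(const std::string& model, const ConfigSketch* config) {
  std::ostringstream ss;
  ss << model;
  if (config != nullptr) {
    ss << "parameters:" << '\n' << config->ToString() << "\n";
  }
  return ss.str();
}
```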
@@ -39,7 +39,7 @@ public:
 #endif
   }
 
-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
             const std::vector<const Metric*>& training_metrics) override {
     GBDT::Init(config, train_data, objective_function, training_metrics);
     ResetGoss();
@@ -51,15 +51,15 @@ public:
     ResetGoss();
   }
 
-  void ResetConfig(const BoostingConfig* config) override {
+  void ResetConfig(const Config* config) override {
     GBDT::ResetConfig(config);
     ResetGoss();
   }
 
   void ResetGoss() {
-    CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
-    CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
-    if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
+    CHECK(config_->top_rate + config_->other_rate <= 1.0f);
+    CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
+    if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) {
       Log::Fatal("Cannot use bagging in GOSS");
     }
     Log::Info("Using GOSS");
@@ -74,8 +74,8 @@ public:
     right_write_pos_buf_.resize(num_threads_);
     is_use_subset_ = false;
-    if (gbdt_config_->top_rate + gbdt_config_->other_rate <= 0.5) {
-      auto bag_data_cnt = static_cast<data_size_t>((gbdt_config_->top_rate + gbdt_config_->other_rate) * num_data_);
+    if (config_->top_rate + config_->other_rate <= 0.5) {
+      auto bag_data_cnt = static_cast<data_size_t>((config_->top_rate + config_->other_rate) * num_data_);
       bag_data_cnt = std::max(1, bag_data_cnt);
       tmp_subset_.reset(new Dataset(bag_data_cnt));
       tmp_subset_->CopyFeatureMapperFrom(train_data_);
@@ -93,8 +93,8 @@ public:
         tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
       }
     }
-    data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate);
-    data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate);
+    data_size_t top_k = static_cast<data_size_t>(cnt * config_->top_rate);
+    data_size_t other_k = static_cast<data_size_t>(cnt * config_->other_rate);
     top_k = std::max(1, top_k);
     ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k - 1);
     score_t threshold = tmp_gradients[top_k - 1];
@@ -135,7 +135,7 @@ public:
   void Bagging(int iter) override {
     bag_data_cnt_ = num_data_;
     // not subsample for first iterations
-    if (iter < static_cast<int>(1.0f / gbdt_config_->learning_rate)) { return; }
+    if (iter < static_cast<int>(1.0f / config_->learning_rate)) { return; }
     const data_size_t min_inner_size = 100;
     data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
@@ -150,7 +150,7 @@ public:
       if (cur_start > num_data_) { continue; }
       data_size_t cur_cnt = inner_size;
       if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
-      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
+      Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
       data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt,
                                                  tmp_indices_.data() + cur_start, tmp_indice_right_.data() + cur_start);
       offsets_buf_[i] = cur_start;
...
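The GOSS hunks above rename the config member but keep the sampling scheme: keep the `top_rate` fraction of rows with the largest `|gradient * hessian|`, then sample `other_rate` of the remainder (whose gradients real GOSS then amplifies by `(1 - top_rate) / other_rate`). A hedged, self-contained illustration of that selection step; it uses a full sort plus a deterministic stride through the tail, not LightGBM's `ArgMaxAtK`-and-random-threshold implementation, and all names are illustrative:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// Returns the indices GOSS-style sampling would keep: the top_k rows by
// |gradient|, plus a deterministic stand-in for randomly sampling
// other_k rows from the rest.
std::vector<int> GossSample(const std::vector<double>& gradients,
                            double top_rate, double other_rate) {
  const int n = static_cast<int>(gradients.size());
  const int top_k = std::max(1, static_cast<int>(n * top_rate));
  const int other_k = static_cast<int>(n * other_rate);
  std::vector<int> idx(n);
  std::iota(idx.begin(), idx.end(), 0);
  // sort indices by |gradient|, descending
  std::sort(idx.begin(), idx.end(), [&](int a, int b) {
    return std::fabs(gradients[a]) > std::fabs(gradients[b]);
  });
  std::vector<int> sampled(idx.begin(), idx.begin() + top_k);
  // stride through the small-gradient tail in place of random sampling
  for (int i = 0; i < other_k && top_k + i * 2 < n; ++i) {
    sampled.push_back(idx[top_k + i * 2]);
  }
  return sampled;
}
```

This is also why `ResetGoss()` checks `top_rate + other_rate <= 1.0` and forbids combining GOSS with bagging: GOSS already subsamples the data on its own.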
@@ -24,10 +24,10 @@ public:
   ~RF() {}
 
-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
             const std::vector<const Metric*>& training_metrics) override {
     CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
-    CHECK(config->tree_config.feature_fraction < 1.0f && config->tree_config.feature_fraction > 0.0f);
+    CHECK(config->feature_fraction < 1.0f && config->feature_fraction > 0.0f);
     GBDT::Init(config, train_data, objective_function, training_metrics);
     if (num_init_iteration_ > 0) {
@@ -50,9 +50,9 @@ public:
     }
   }
 
-  void ResetConfig(const BoostingConfig* config) override {
+  void ResetConfig(const Config* config) override {
     CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
-    CHECK(config->tree_config.feature_fraction < 1.0f && config->tree_config.feature_fraction > 0.0f);
+    CHECK(config->feature_fraction < 1.0f && config->feature_fraction > 0.0f);
     GBDT::ResetConfig(config);
     // no shrinkage rate for the RF
     shrinkage_rate_ = 1.0f;
...
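Taken together, the diff applies one pattern everywhere: the nested `BoostingConfig`/`TreeConfig` pair is flattened into a single `Config` that the booster owns and every component shares, so call sites pass `config_.get()` instead of `&config->tree_config`, and members like `feature_fraction` are read directly. A miniature of the pattern, with illustrative names that are not LightGBM's API:

```cpp
#include <cassert>
#include <memory>
#include <string>

// One flat config object instead of BoostingConfig + nested TreeConfig.
struct Config {
  double learning_rate = 0.1;
  double feature_fraction = 1.0;  // formerly tree_config.feature_fraction
  std::string tree_learner = "serial";
};

class TreeLearnerSketch {
 public:
  void ResetConfig(const Config* config) { config_ = config; }
  double FeatureFraction() const { return config_->feature_fraction; }
 private:
  const Config* config_ = nullptr;
};

class BoosterSketch {
 public:
  void Init(const Config* config) {
    config_.reset(new Config(*config));  // own a copy, as GBDT::Init does
    learner_.ResetConfig(config_.get()); // share the same flat object
  }
  const Config* config() const { return config_.get(); }
 private:
  std::unique_ptr<Config> config_;
  TreeLearnerSketch learner_;
};
```

One object with one lifetime means a `ResetConfig` call updates boosting and tree-learning parameters together, instead of keeping two nested structs in sync.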