Unverified Commit dc699574 authored by Guolin Ke, committed by GitHub

Refine config object (#1381)

* [WIP] refine config

* [wip] ready for the auto code generate

* auto generate config codes

* use with to open file

* fix bug

* fix pylint

* fix bug

* fix pylint

* fix bugs.

* tmp for failed test.

* fix tests.

* added nthreads alias

* added new aliases from new config.h

* fixed duplicated alias

* refactored parameter_generator.py

* added new aliases from config.h and removed remaining old names

* fix bugs & some miss alias

* added aliases

* add more descriptions.

* add comment.
parent 497e60ed
......@@ -32,7 +32,7 @@ Core Parameters
- **Note**: can only be used in the CLI version
- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``
- ``task``, default=\ ``train``, type=enum, options=\ ``train``, ``predict``, ``convert_model``, ``refit``, alias=\ ``task_type``
- ``train``, alias=\ ``training``, for training
......@@ -47,7 +47,7 @@ Core Parameters
- ``application``, default=\ ``regression``, type=enum,
options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``,
``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``,
alias=\ ``objective``, ``app``
alias=\ ``app``, ``objective``, ``objective_type``
- regression application
......@@ -107,11 +107,11 @@ Core Parameters
- ``goss``, Gradient-based One-Side Sampling
- ``data``, default=\ ``""``, type=string, alias=\ ``train``, ``train_data``
- ``data``, default=\ ``""``, type=string, alias=\ ``train``, ``train_data``, ``data_filename``
- training data, LightGBM will train from this data
- ``valid``, default=\ ``""``, type=multi-string, alias=\ ``test``, ``valid_data``, ``test_data``
- ``valid``, default=\ ``""``, type=multi-string, alias=\ ``test``, ``valid_data``, ``test_data``, ``valid_filenames``
- validation/test data, LightGBM will output metrics for these data
......@@ -137,7 +137,7 @@ Core Parameters
- number of leaves in one tree
- ``tree_learner``, default=\ ``serial``, type=enum, options=\ ``serial``, ``feature``, ``data``, ``voting``, alias=\ ``tree``
- ``tree_learner``, default=\ ``serial``, type=enum, options=\ ``serial``, ``feature``, ``data``, ``voting``, alias=\ ``tree``, ``tree_learner_type``
- ``serial``, single machine tree learner
......@@ -149,7 +149,7 @@ Core Parameters
- refer to `Parallel Learning Guide <./Parallel-Learning-Guide.rst>`__ to get more details
- ``num_threads``, default=\ ``OpenMP_default``, type=int, alias=\ ``num_thread``, ``nthread``
- ``num_threads``, default=\ ``OpenMP_default``, type=int, alias=\ ``num_thread``, ``nthread``, ``nthreads``
- number of threads for LightGBM
......@@ -204,7 +204,7 @@ Learning Control Parameters
- random seed for ``feature_fraction``
- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``
- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``, ``bagging``
- like ``feature_fraction``, but this will randomly select part of data without resampling
......@@ -312,7 +312,7 @@ Learning Control Parameters
- set this to a larger value for more accurate results, but it will slow down the training speed
- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``
- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``, ``monotone_constraints``
- used for constraints of monotonic features
......@@ -443,7 +443,7 @@ IO Parameters
- **Note**: the negative values will be treated as **missing values**
- ``predict_raw_score``, default=\ ``false``, type=bool, alias=\ ``raw_score``, ``is_predict_raw_score``
- ``predict_raw_score``, default=\ ``false``, type=bool, alias=\ ``raw_score``, ``is_predict_raw_score``, ``predict_rawscore``
- only used in ``prediction`` task
......@@ -501,17 +501,17 @@ IO Parameters
- set to ``false`` to use ``na`` to represent missing values
- ``init_score_file``, default=\ ``""``, type=string
- ``init_score_file``, default=\ ``""``, type=string, alias=\ ``init_score_filename``, ``initscore_filename``, ``init_score``
- path to training initial score file, ``""`` will use ``train_data_file`` + ``.init`` (if exists)
- ``valid_init_score_file``, default=\ ``""``, type=multi-string
- ``valid_init_score_file``, default=\ ``""``, type=multi-string, alias=\ ``valid_data_initscores``, ``valid_data_init_scores``, ``valid_init_score``
- path to validation initial score file, ``""`` will use ``valid_data_file`` + ``.init`` (if exists)
- separated by ``,`` for multi-validation data
- ``forced_splits``, default=\ ``""``, type=string
- ``forced_splits``, default=\ ``""``, type=string, alias=\ ``forced_splits_file``, ``forcedsplits_filename``, ``forced_splits_filename``
- path to a ``.json`` file that specifies splits to force at the top of every decision tree before best-first learning commences
......@@ -593,7 +593,7 @@ Objective Parameters
Metric Parameters
-----------------
- ``metric``, default=\ ``''``, type=multi-enum
- ``metric``, default=\ ``''``, type=multi-enum, alias=\ ``metric_types``
- metric to be evaluated on the evaluation sets **in addition** to what is provided in the training arguments
......@@ -650,7 +650,7 @@ Metric Parameters
- frequency for metric output
- ``train_metric``, default=\ ``false``, type=bool, alias=\ ``training_metric``, ``is_training_metric``
- ``train_metric``, default=\ ``false``, type=bool, alias=\ ``training_metric``, ``is_training_metric``, ``is_provide_training_metric``
- set this to ``true`` if you need to output metric results of the training data
......
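Taken together, the alias additions above let several historical spellings resolve to a single canonical parameter. A minimal Python sketch of that canonicalization, with alias pairs copied from the documentation entries above (LightGBM itself does this in C++ via the generated Config::alias_table):

# Sketch only: a few alias pairs from the docs above, not the full table.
ALIAS_TABLE = {
    "task_type": "task",
    "num_thread": "num_threads",
    "nthread": "num_threads",
    "nthreads": "num_threads",
    "sub_row": "bagging_fraction",
    "subsample": "bagging_fraction",
    "bagging": "bagging_fraction",
}

def canonicalize(params):
    """Rewrite alias keys to canonical names; reject conflicting duplicates."""
    out = {}
    for key, value in params.items():
        canonical = ALIAS_TABLE.get(key, key)
        if canonical in out and out[canonical] != value:
            raise ValueError("conflicting values for " + canonical)
        out[canonical] = value
    return out

print(canonicalize({"nthreads": 8, "bagging": 0.8}))
# {'num_threads': 8, 'bagging_fraction': 0.8}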
import os


def GetParameterInfos(config_hpp):
is_inparameter = False
parameter_group = None
cur_key = None
cur_info = {}
keys = []
member_infos = []
with open(config_hpp) as config_hpp_file:
for line in config_hpp_file:
if "#pragma region Parameters" in line:
is_inparameter = True
elif "#pragma region" in line and "Parameters" in line:
cur_key = line.split("region")[1].strip()
keys.append(cur_key)
member_infos.append([])
elif '#pragma endregion' in line:
if cur_key is not None:
cur_key = None
elif is_inparameter:
is_inparameter = False
elif cur_key is not None:
line = line.strip()
if line.startswith("//"):
tokens = line.split("//")[1].split("=")
key = tokens[0].strip()
val = '='.join(tokens[1:]).strip()
if key not in cur_info:
if key == "descl2":
cur_info["desc"] = []
else:
cur_info[key] = []
if key == "desc":
cur_info["desc"].append(["l1", val])
elif key == "descl2":
cur_info["desc"].append(["l2", val])
else:
cur_info[key].append(val)
elif line:
has_eqsgn = False
tokens = line.split("=")
if len(tokens) == 2:
if "default" not in cur_info:
cur_info["default"] = [tokens[1][:-1].strip()]
has_eqsgn = True
tokens = line.split()
cur_info["inner_type"] = [tokens[0].strip()]
if "name" not in cur_info:
if has_eqsgn:
cur_info["name"] = [tokens[1].strip()]
else:
cur_info["name"] = [tokens[1][:-1].strip()]
member_infos[-1].append(cur_info)
cur_info = {}
    return (keys, member_infos)


def GetNames(infos):
names = []
for x in infos:
for y in x:
names.append(y["name"][0])
    return names


def GetAlias(infos):
pairs = []
for x in infos:
for y in x:
if "alias" in y:
name = y["name"][0]
alias = y["alias"][0].split(',')
for name2 in alias:
pairs.append([name2.strip(), name])
    return pairs


def SetOneVarFromString(name, type, checks):
ret = ""
univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"}
if "vector" not in type:
ret += " %s(params, \"%s\", &%s);\n" % (univar_mapper[type], name, name)
if len(checks) > 0:
for check in checks:
ret += " CHECK(%s %s);\n" % (name, check)
ret += "\n"
else:
ret += " if (GetString(params, \"%s\", &tmp_str)) {\n" % (name)
type2 = type.split("<")[1][:-1]
if type2 == "std::string":
ret += " %s = Common::Split(tmp_str.c_str(), ',');\n" % (name)
else:
ret += " %s = Common::StringToArray<%s>(tmp_str, ',');\n" % (name, type2)
ret += " }\n\n"
    return ret


def GenParameterCode(config_hpp, config_out_cpp):
keys, infos = GetParameterInfos(config_hpp)
names = GetNames(infos)
alias = GetAlias(infos)
str_to_write = "/// This file is auto generated by LightGBM\\helper\\parameter_generator.py\n"
str_to_write += "#include<LightGBM/config.h>\nnamespace LightGBM {\n"
# alias table
str_to_write += "std::unordered_map<std::string, std::string> Config::alias_table({\n"
for pair in alias:
str_to_write += " {\"%s\", \"%s\"}, \n" % (pair[0], pair[1])
str_to_write += "});\n\n"
# names
str_to_write += "std::unordered_set<std::string> Config::parameter_set({\n"
for name in names:
str_to_write += " \"%s\", \n" % (name)
str_to_write += "});\n\n"
# from strings
str_to_write += "void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {\n"
str_to_write += " std::string tmp_str = \"\";\n"
for x in infos:
for y in x:
if "[doc-only]" in y:
continue
type = y["inner_type"][0]
name = y["name"][0]
checks = []
if "check" in y:
checks = y["check"]
tmp = SetOneVarFromString(name, type, checks)
str_to_write += tmp
# tails
str_to_write += "}\n\n"
str_to_write += "std::string Config::SaveMembersToString() const {\n"
str_to_write += " std::stringstream str_buf;\n"
for x in infos:
for y in x:
if "[doc-only]" in y:
continue
type = y["inner_type"][0]
name = y["name"][0]
if "vector" in type:
if "int8" in type:
str_to_write += " str_buf << \"[%s: \" << Common::Join(Common::ArrayCast<int8_t, int>(%s),\",\") << \"]\\n\";\n" % (name, name)
else:
str_to_write += " str_buf << \"[%s: \" << Common::Join(%s,\",\") << \"]\\n\";\n" % (name, name)
else:
str_to_write += " str_buf << \"[%s: \" << %s << \"]\\n\";\n" % (name, name)
# tails
str_to_write += " return str_buf.str();\n"
str_to_write += "}\n\n"
str_to_write += "}\n"
with open(config_out_cpp, "w") as config_out_cpp_file:
        config_out_cpp_file.write(str_to_write)


if __name__ == "__main__":
config_hpp = os.path.join(os.path.pardir, 'include', 'LightGBM', 'config.h')
config_out_cpp = os.path.join(os.path.pardir, 'src', 'io', 'config_auto.cpp')
GenParameterCode(config_hpp, config_out_cpp)
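A hedged usage sketch of the parser above, assuming the script is saved as parameter_generator.py and is importable. The sample header mimics the "#pragma region" / "// key = value" comment format the parser walks through; the exact conventions in the real config.h may differ:

import os
import tempfile
from parameter_generator import GetAlias, GetNames, GetParameterInfos

# A miniature config.h: one parameter group containing one documented member.
sample = """\
#pragma region Parameters
#pragma region Core Parameters
  // alias = num_thread, nthread, nthreads
  // desc = number of threads for LightGBM
  int num_threads = 0;
#pragma endregion
#pragma endregion
"""

with tempfile.NamedTemporaryFile("w", suffix=".h", delete=False) as f:
    f.write(sample)
    path = f.name

keys, infos = GetParameterInfos(path)
os.remove(path)
print(keys)             # ['Core Parameters']
print(GetNames(infos))  # ['num_threads']
print(GetAlias(infos))  # [['num_thread', 'num_threads'], ['nthread', ...], ...]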
......@@ -56,7 +56,7 @@ private:
void ConvertModel();
/*! \brief All configs */
OverallConfig config_;
Config config_;
/*! \brief Training data */
std::unique_ptr<Dataset> train_data_;
/*! \brief Validation data */
......@@ -73,10 +73,10 @@ private:
inline void Application::Run() {
if (config_.task_type == TaskType::kPredict || config_.task_type == TaskType::KRefitTree) {
if (config_.task == TaskType::kPredict || config_.task == TaskType::KRefitTree) {
InitPredict();
Predict();
} else if (config_.task_type == TaskType::kConvertModel) {
} else if (config_.task == TaskType::kConvertModel) {
ConvertModel();
} else {
InitTrain();
......
......@@ -32,7 +32,7 @@ public:
* \param training_metrics Training metric
*/
virtual void Init(
const BoostingConfig* config,
const Config* config,
const Dataset* train_data,
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;
......@@ -47,7 +47,7 @@ public:
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;
virtual void ResetConfig(const BoostingConfig* config) = 0;
virtual void ResetConfig(const Config* config) = 0;
......
......@@ -273,7 +273,7 @@ public:
* \param label_idx index of label column
* \return Object of parser
*/
static Parser* CreateParser(const char* filename, bool has_header, int num_features, int label_idx);
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx);
};
/*! \brief The main class of data set,
......@@ -292,7 +292,7 @@ public:
int** sample_non_zero_indices,
const int* num_per_col,
size_t total_sample_cnt,
const IOConfig& io_config);
const Config& io_config);
/*! \brief Destructor */
LIGHTGBM_EXPORT ~Dataset();
......
......@@ -8,7 +8,7 @@ namespace LightGBM {
class DatasetLoader {
public:
LIGHTGBM_EXPORT DatasetLoader(const IOConfig& io_config, const PredictFunction& predict_fun, int num_class, const char* filename);
LIGHTGBM_EXPORT DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename);
LIGHTGBM_EXPORT ~DatasetLoader();
......@@ -54,7 +54,7 @@ private:
/*! \brief Check can load from binary file */
std::string CheckCanLoadFromBin(const char* filename);
const IOConfig& io_config_;
const Config& config_;
/*! \brief Random generator*/
Random random_;
/*! \brief prediction function for initial model */
......
......@@ -47,7 +47,7 @@ public:
* \param type Specific type of metric
* \param config Config for metric
*/
LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type, const MetricConfig& config);
LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type, const Config& config);
};
......@@ -56,11 +56,14 @@ public:
*/
class DCGCalculator {
public:
static void DefaultEvalAt(std::vector<int>* eval_at);
static void DefaultLabelGain(std::vector<double>* label_gain);
/*!
* \brief Initial logic
* \param label_gain Gain for labels, default is 2^i - 1
*/
static void Init(std::vector<double> label_gain);
static void Init(const std::vector<double>& label_gain);
/*!
* \brief Calculate the DCG score at position k
......
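The new DefaultLabelGain hook fills in the "2^i - 1" default mentioned in the comment above; a one-line illustration (the number of relevance levels here is arbitrary):

label_gain = [2 ** i - 1 for i in range(5)]
print(label_gain)  # [0, 1, 3, 7, 15]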
......@@ -89,7 +89,7 @@ public:
* \brief Initialize
* \param config Config of network setting
*/
static void Init(NetworkConfig config);
static void Init(Config config);
/*!
* \brief Initialize
*/
......
......@@ -71,7 +71,7 @@ public:
* \param config Config for objective function
*/
LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& type,
const ObjectiveConfig& config);
const Config& config);
/*!
* \brief Load objective function from string object
......
......@@ -170,7 +170,7 @@ public:
std::string ToJSON() const;
/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index) const;
std::string ToIfElse(int index, bool predict_leaf_index) const;
inline static bool IsZero(double fval) {
if (fval > -kZeroThreshold && fval <= kZeroThreshold) {
......@@ -307,9 +307,9 @@ private:
std::string NodeToJSON(int index) const;
/*! \brief Serialize one node to if-else statement*/
std::string NodeToIfElse(int index, bool is_predict_leaf_index) const;
std::string NodeToIfElse(int index, bool predict_leaf_index) const;
std::string NodeToIfElseByMap(int index, bool is_predict_leaf_index) const;
std::string NodeToIfElseByMap(int index, bool predict_leaf_index) const;
double ExpectedValue() const;
......
......@@ -36,9 +36,9 @@ public:
/*!
* \brief Reset tree configs
* \param tree_config config of tree
* \param config config of tree
*/
virtual void ResetConfig(const TreeConfig* tree_config) = 0;
virtual void ResetConfig(const Config* config) = 0;
/*!
* \brief training tree model on dataset
......@@ -85,11 +85,11 @@ public:
* \brief Create object of tree learner
* \param learner_type Type of tree learner
* \param device_type Type of tree learner
* \param tree_config config of tree
* \param config config of tree
*/
static TreeLearner* CreateTreeLearner(const std::string& learner_type,
const std::string& device_type,
const TreeConfig* tree_config);
const Config* config);
};
} // namespace LightGBM
......
......@@ -33,7 +33,7 @@ Application::Application(int argc, char** argv) {
if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}
omp_set_nested(0);
......@@ -48,13 +48,13 @@ Application::~Application() {
void Application::LoadParameters(int argc, char** argv) {
std::unordered_map<std::string, std::string> params;
for (int i = 1; i < argc; ++i) {
ConfigBase::KV2Map(params, argv[i]);
Config::KV2Map(params, argv[i]);
}
// check for alias
ParameterAlias::KeyAliasTransform(&params);
// read parameters from config file
if (params.count("config_file") > 0) {
TextReader<size_t> config_reader(params["config_file"].c_str(), false);
if (params.count("config") > 0) {
TextReader<size_t> config_reader(params["config"].c_str(), false);
config_reader.ReadAllLines();
if (!config_reader.Lines().empty()) {
for (auto& line : config_reader.Lines()) {
......@@ -66,11 +66,11 @@ void Application::LoadParameters(int argc, char** argv) {
if (line.size() == 0) {
continue;
}
ConfigBase::KV2Map(params, line.c_str());
Config::KV2Map(params, line.c_str());
}
} else {
Log::Warning("Config file %s doesn't exist, will ignore",
params["config_file"].c_str());
params["config"].c_str());
}
}
// check for alias again
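Config::KV2Map (renamed from ConfigBase::KV2Map above) turns each "key=value" token, whether a command-line argument or a config-file line, into a map entry. A rough Python equivalent; the exact whitespace and comment handling of the C++ version is an assumption:

def kv2map(params, token):
    # Split only on the first '=', so values may themselves contain '='.
    if "=" in token:
        key, value = token.split("=", 1)
        if key.strip():
            params[key.strip()] = value.strip()

params = {}
for arg in ["task=train", "config=train.conf", "num_threads=8"]:
    kv2map(params, arg)
print(params)  # {'task': 'train', 'config': 'train.conf', 'num_threads': '8'}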
......@@ -87,37 +87,37 @@ void Application::LoadData() {
PredictFunction predict_fun = nullptr;
PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
// need to continue training
if (boosting_->NumberOfTotalModel() > 0 && config_.task_type != TaskType::KRefitTree) {
if (boosting_->NumberOfTotalModel() > 0 && config_.task != TaskType::KRefitTree) {
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
predict_fun = predictor->GetPredictFunction();
}
// sync up random seed for data partition
if (config_.is_parallel_find_bin) {
config_.io_config.data_random_seed = Network::GlobalSyncUpByMin(config_.io_config.data_random_seed);
config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
}
DatasetLoader dataset_loader(config_.io_config, predict_fun,
config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
DatasetLoader dataset_loader(config_, predict_fun,
config_.num_class, config_.data.c_str());
// load Training data
if (config_.is_parallel_find_bin) {
// load data for parallel training
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
config_.io_config.initscore_filename.c_str(),
train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
config_.initscore_filename.c_str(),
Network::rank(), Network::num_machines()));
} else {
// load data for single machine
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
0, 1));
}
// need save binary file
if (config_.io_config.is_save_binary_file) {
if (config_.save_binary) {
train_data_->SaveBinaryFile(nullptr);
}
// create training metric
if (config_.boosting_config.is_provide_training_metric) {
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (config_.is_provide_training_metric) {
for (auto metric_type : config_.metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric));
......@@ -126,28 +126,28 @@ void Application::LoadData() {
train_metric_.shrink_to_fit();
if (!config_.metric_types.empty()) {
if (!config_.metric.empty()) {
// only when have metrics then need to construct validation data
// Add validation data, if it exists
for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) {
for (size_t i = 0; i < config_.valid.size(); ++i) {
// add
auto new_dataset = std::unique_ptr<Dataset>(
dataset_loader.LoadFromFileAlignWithOtherDataset(
config_.io_config.valid_data_filenames[i].c_str(),
config_.io_config.valid_data_initscores[i].c_str(),
config_.valid[i].c_str(),
config_.valid_data_initscores[i].c_str(),
train_data_.get())
);
valid_datas_.push_back(std::move(new_dataset));
// need save binary file
if (config_.io_config.is_save_binary_file) {
if (config_.save_binary) {
valid_datas_.back()->SaveBinaryFile(nullptr);
}
// add metric for validation data
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
for (auto metric_type : config_.metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_.back()->metadata(),
valid_datas_.back()->num_data());
......@@ -167,30 +167,30 @@ void Application::LoadData() {
void Application::InitTrain() {
if (config_.is_parallel) {
// need init network
Network::Init(config_.network_config);
Network::Init(config_);
Log::Info("Finished initializing network");
config_.boosting_config.tree_config.feature_fraction_seed =
Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction_seed);
config_.boosting_config.tree_config.feature_fraction =
Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction);
config_.boosting_config.drop_seed =
Network::GlobalSyncUpByMin(config_.boosting_config.drop_seed);
config_.feature_fraction_seed =
Network::GlobalSyncUpByMin(config_.feature_fraction_seed);
config_.feature_fraction =
Network::GlobalSyncUpByMin(config_.feature_fraction);
config_.drop_seed =
Network::GlobalSyncUpByMin(config_.drop_seed);
}
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str()));
Boosting::CreateBoosting(config_.boosting,
config_.input_model.c_str()));
// create objective function
objective_fun_.reset(
ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
ObjectiveFunction::CreateObjectiveFunction(config_.objective,
config_));
// load training data
LoadData();
// initialize the objective function
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) {
......@@ -202,22 +202,22 @@ void Application::InitTrain() {
void Application::Train() {
Log::Info("Started training...");
boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
boosting_->Train(config_.snapshot_freq, config_.output_model);
boosting_->SaveModelToFile(-1, config_.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
}
Log::Info("Finished training");
}
void Application::Predict() {
if (config_.task_type == TaskType::KRefitTree) {
if (config_.task == TaskType::KRefitTree) {
// create predictor
Predictor predictor(boosting_.get(), -1, false, true, false, false, 1, 1);
predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str(), config_.io_config.has_header);
TextReader<int> result_reader(config_.io_config.output_result.c_str(), false);
predictor.Predict(config_.data.c_str(), config_.output_result.c_str(), config_.header);
TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());
#pragma omp parallel for schedule(static)
......@@ -226,41 +226,41 @@ void Application::Predict() {
// Free memory
result_reader.Lines()[i].clear();
}
DatasetLoader dataset_loader(config_.io_config, nullptr,
config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
DatasetLoader dataset_loader(config_, nullptr,
config_.num_class, config_.data.c_str());
train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), config_.initscore_filename.c_str(),
0, 1));
train_metric_.clear();
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
config_));
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf);
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
boosting_->SaveModelToFile(-1, config_.output_model.c_str());
Log::Info("Finished RefitTree");
} else {
// create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
config_.io_config.pred_early_stop, config_.io_config.pred_early_stop_freq,
config_.io_config.pred_early_stop_margin);
predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header);
Predictor predictor(boosting_.get(), config_.num_iteration_predict, config_.predict_raw_score,
config_.predict_leaf_index, config_.predict_contrib,
config_.pred_early_stop, config_.pred_early_stop_freq,
config_.pred_early_stop_margin);
predictor.Predict(config_.data.c_str(),
config_.output_result.c_str(), config_.header);
Log::Info("Finished prediction");
}
}
void Application::InitPredict() {
boosting_.reset(
Boosting::CreateBoosting("gbdt", config_.io_config.input_model.c_str()));
Boosting::CreateBoosting("gbdt", config_.input_model.c_str()));
Log::Info("Finished initializing prediction, total used %d iterations", boosting_->GetCurrentIteration());
}
void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type, config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
Boosting::CreateBoosting(config_.boosting, config_.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
}
......
......@@ -29,11 +29,11 @@ public:
* \param boosting Input boosting model
* \param num_iteration Number of boosting round
* \param is_raw_score True if need to predict result with raw score
* \param is_predict_leaf_index True to output leaf index instead of prediction score
* \param is_predict_contrib True to output feature contributions instead of prediction score
* \param predict_leaf_index True to output leaf index instead of prediction score
* \param predict_contrib True to output feature contributions instead of prediction score
*/
Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index, bool is_predict_contrib,
bool is_raw_score, bool predict_leaf_index, bool predict_contrib,
bool early_stop, int early_stop_freq, double early_stop_margin) {
early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
......@@ -55,14 +55,14 @@ public:
{
num_threads_ = omp_get_num_threads();
}
boosting->InitPredict(num_iteration, is_predict_contrib);
boosting->InitPredict(num_iteration, predict_contrib);
boosting_ = boosting;
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, predict_leaf_index, predict_contrib);
num_feature_ = boosting_->MaxFeatureIdx() + 1;
predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
const int kFeatureThreshold = 100000;
const size_t KSparseThreshold = static_cast<size_t>(0.01 * num_feature_);
if (is_predict_leaf_index) {
if (predict_leaf_index) {
predict_fun_ = [this, kFeatureThreshold, KSparseThreshold](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num();
if (num_feature_ > kFeatureThreshold && features.size() < KSparseThreshold) {
......@@ -75,7 +75,7 @@ public:
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
}
};
} else if (is_predict_contrib) {
} else if (predict_contrib) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
......@@ -127,27 +127,27 @@ public:
* \param data_filename Filename of data
* \param result_filename Filename of output result
*/
void Predict(const char* data_filename, const char* result_filename, bool has_header) {
void Predict(const char* data_filename, const char* result_filename, bool header) {
auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be found", result_filename);
}
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
}
TextReader<data_size_t> predict_data_reader(data_filename, has_header);
TextReader<data_size_t> predict_data_reader(data_filename, header);
std::unordered_map<int, int> feature_names_map_;
bool need_adjust = false;
if (has_header) {
if (header) {
std::string first_line = predict_data_reader.first_line();
std::vector<std::string> header = Common::Split(first_line.c_str(), "\t,");
header.erase(header.begin() + boosting_->LabelIdx());
for (int i = 0; i < static_cast<int>(header.size()); ++i) {
std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
header_words.erase(header_words.begin() + boosting_->LabelIdx());
for (int i = 0; i < static_cast<int>(header_words.size()); ++i) {
for (int j = 0; j < static_cast<int>(boosting_->FeatureNames().size()); ++j) {
if (header[i] == boosting_->FeatureNames()[j]) {
if (header_words[i] == boosting_->FeatureNames()[j]) {
feature_names_map_[i] = j;
break;
}
......
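Renaming the flag to header also untangles it from the local header_words vector: the first line of the prediction file is split on tab/comma, the label column is dropped, and each remaining column is matched against the model's stored feature names. A compact sketch of that mapping:

def build_feature_map(first_line, label_idx, model_feature_names):
    # Split on tab or comma, as Common::Split(first_line.c_str(), "\t,") does.
    words = first_line.replace("\t", ",").split(",")
    del words[label_idx]  # drop the label column, mirroring the code above
    feature_names_map = {}
    for i, word in enumerate(words):
        for j, name in enumerate(model_feature_names):
            if word == name:
                feature_names_map[i] = j
                break
    return feature_names_map

print(build_feature_map("label,f2,f1", 0, ["f1", "f2"]))  # {0: 1, 1: 0}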
......@@ -32,17 +32,17 @@ public:
* \param training_metrics Training metrics
* \param output_model_filename Filename of output model
*/
void Init(const BoostingConfig* config, const Dataset* train_data,
void Init(const Config* config, const Dataset* train_data,
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
random_for_drop_ = Random(gbdt_config_->drop_seed);
random_for_drop_ = Random(config_->drop_seed);
sum_weight_ = 0.0f;
}
void ResetConfig(const BoostingConfig* config) override {
void ResetConfig(const Config* config) override {
GBDT::ResetConfig(config);
random_for_drop_ = Random(gbdt_config_->drop_seed);
random_for_drop_ = Random(config_->drop_seed);
sum_weight_ = 0.0f;
}
......@@ -57,7 +57,7 @@ public:
}
// normalize
Normalize();
if (!gbdt_config_->uniform_drop) {
if (!config_->uniform_drop) {
tree_weight_.push_back(shrinkage_rate_);
sum_weight_ += shrinkage_rate_;
}
......@@ -85,31 +85,31 @@ private:
*/
void DroppingTrees() {
drop_index_.clear();
bool is_skip = random_for_drop_.NextFloat() < gbdt_config_->skip_drop;
bool is_skip = random_for_drop_.NextFloat() < config_->skip_drop;
// select dropping tree indices based on drop_rate and tree weights
if (!is_skip) {
double drop_rate = gbdt_config_->drop_rate;
if (!gbdt_config_->uniform_drop) {
double drop_rate = config_->drop_rate;
if (!config_->uniform_drop) {
double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
if (gbdt_config_->max_drop > 0) {
drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
if (config_->max_drop > 0) {
drop_rate = std::min(drop_rate, config_->max_drop * inv_average_weight / sum_weight_);
}
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
break;
}
}
}
} else {
if (gbdt_config_->max_drop > 0) {
drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
if (config_->max_drop > 0) {
drop_rate = std::min(drop_rate, config_->max_drop / static_cast<double>(iter_));
}
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate) {
drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
if (drop_index_.size() >= static_cast<size_t>(config_->max_drop)) {
break;
}
}
......@@ -124,13 +124,13 @@ private:
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
}
if (!gbdt_config_->xgboost_dart_mode) {
shrinkage_rate_ = gbdt_config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
if (!config_->xgboost_dart_mode) {
shrinkage_rate_ = config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
} else {
if (drop_index_.empty()) {
shrinkage_rate_ = gbdt_config_->learning_rate;
shrinkage_rate_ = config_->learning_rate;
} else {
shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
shrinkage_rate_ = config_->learning_rate / (config_->learning_rate + static_cast<double>(drop_index_.size()));
}
}
}
......@@ -146,7 +146,7 @@ private:
*/
void Normalize() {
double k = static_cast<double>(drop_index_.size());
if (!gbdt_config_->xgboost_dart_mode) {
if (!config_->xgboost_dart_mode) {
for (auto i : drop_index_) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
......@@ -159,7 +159,7 @@ private:
models_[curr_tree]->Shrinkage(-k);
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
if (!gbdt_config_->uniform_drop) {
if (!config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
tree_weight_[i] *= (k / (k + 1.0f));
}
......@@ -174,12 +174,12 @@ private:
score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
}
// update training score
models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
models_[curr_tree]->Shrinkage(-k / config_->learning_rate);
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
if (!gbdt_config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));
tree_weight_[i] *= (k / (k + gbdt_config_->learning_rate));
if (!config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + config_->learning_rate));
tree_weight_[i] *= (k / (k + config_->learning_rate));
}
}
}
......
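The gbdt_config_ to config_ rename leaves the DART shrinkage arithmetic untouched; condensed into a sketch, where k is the number of dropped trees:

def dart_shrinkage(learning_rate, k, xgboost_dart_mode):
    # Mirrors the two branches above.
    if not xgboost_dart_mode:
        return learning_rate / (1.0 + k)
    return learning_rate if k == 0 else learning_rate / (learning_rate + k)

print(dart_shrinkage(0.1, 3, False))  # 0.025
print(dart_shrinkage(0.1, 3, True))   # ~0.0323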
......@@ -61,7 +61,7 @@ GBDT::~GBDT() {
#endif
}
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void GBDT::Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
CHECK(train_data != nullptr);
CHECK(train_data->num_features() > 0);
......@@ -70,9 +70,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
num_iteration_for_pred_ = 0;
max_feature_idx_ = 0;
num_class_ = config->num_class;
gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
config_ = std::unique_ptr<Config>(new Config(*config));
early_stopping_round_ = config_->early_stopping_round;
shrinkage_rate_ = config_->learning_rate;
std::string forced_splits_path = config->forcedsplits_filename;
//load forced_splits file
......@@ -93,7 +93,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
is_constant_hessian_ = false;
}
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(config_->tree_learner, config_->device_type, config_.get()));
// init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_);
......@@ -123,7 +123,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
feature_infos_ = train_data_->feature_infos();
// if need bagging, create buffer
ResetBaggingConfig(gbdt_config_.get(), true);
ResetBaggingConfig(config_.get(), true);
// reset config for tree learner
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
......@@ -214,7 +214,7 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
if (cnt <= 0) {
return 0;
}
data_size_t bag_data_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
data_size_t bag_data_cnt = static_cast<data_size_t>(config_->bagging_fraction * cnt);
data_size_t cur_left_cnt = 0;
data_size_t cur_right_cnt = 0;
auto right_buffer = buffer + bag_data_cnt;
......@@ -233,7 +233,7 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t
void GBDT::Bagging(int iter) {
// if need bagging
if ((bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0)
if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0)
|| need_re_bagging_) {
need_re_bagging_ = false;
const data_size_t min_inner_size = 1000;
......@@ -249,7 +249,7 @@ void GBDT::Bagging(int iter) {
if (cur_start > num_data_) { continue; }
data_size_t cur_cnt = inner_size;
if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
offsets_buf_[i] = cur_start;
left_cnts_buf_[i] = cur_left_count;
......@@ -318,7 +318,7 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj) {
void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
bool is_finished = false;
auto start_time = std::chrono::steady_clock::now();
for (int iter = 0; iter < gbdt_config_->num_iterations && !is_finished; ++iter) {
for (int iter = 0; iter < config_->num_iterations && !is_finished; ++iter) {
is_finished = TrainOneIter(nullptr, nullptr);
if (!is_finished) {
is_finished = EvalAndCheckEarlyStopping();
......@@ -364,7 +364,7 @@ double GBDT::BoostFromAverage() {
if (models_.empty() && !train_score_updater_->has_init_score()
&& num_class_ <= 1
&& objective_function_ != nullptr) {
if (gbdt_config_->boost_from_average) {
if (config_->boost_from_average) {
double init_score = ObtainAutomaticInitialScore(objective_function_);
if (std::fabs(init_score) > kEpsilon) {
train_score_updater_->AddScore(init_score, 0);
......@@ -580,7 +580,7 @@ std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* scor
}
std::string GBDT::OutputMetric(int iter) {
bool need_output = (iter % gbdt_config_->output_freq) == 0;
bool need_output = (iter % config_->metric_freq) == 0;
std::string ret = "";
std::stringstream msg_buf;
std::vector<std::pair<size_t, size_t>> meet_early_stopping_pairs;
......@@ -777,24 +777,24 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
feature_infos_ = train_data_->feature_infos();
tree_learner_->ResetTrainingData(train_data);
ResetBaggingConfig(gbdt_config_.get(), true);
ResetBaggingConfig(config_.get(), true);
}
}
void GBDT::ResetConfig(const BoostingConfig* config) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
void GBDT::ResetConfig(const Config* config) {
auto new_config = std::unique_ptr<Config>(new Config(*config));
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
if (tree_learner_ != nullptr) {
tree_learner_->ResetConfig(&new_config->tree_config);
tree_learner_->ResetConfig(new_config.get());
}
if (train_data_ != nullptr) {
ResetBaggingConfig(new_config.get(), false);
}
gbdt_config_.reset(new_config.release());
config_.reset(new_config.release());
}
void GBDT::ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset) {
void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
// if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
bag_data_cnt_ =
......
......@@ -43,7 +43,7 @@ public:
* \param objective_function Training objective function
* \param training_metrics Training metrics
*/
void Init(const BoostingConfig* gbdt_config, const Dataset* train_data,
void Init(const Config* gbdt_config, const Dataset* train_data,
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override;
......@@ -83,7 +83,7 @@ public:
* \brief Reset Boosting Config
* \param gbdt_config Config for boosting
*/
void ResetConfig(const BoostingConfig* gbdt_config) override;
void ResetConfig(const Config* gbdt_config) override;
/*!
* \brief Adding a validation dataset
......@@ -335,7 +335,7 @@ protected:
/*!
* \brief reset config for bagging
*/
void ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset);
void ResetBaggingConfig(const Config* config, bool is_change_dataset);
/*!
* \brief Implement bagging logic
......@@ -384,7 +384,7 @@ protected:
/*! \brief Pointer to training data */
const Dataset* train_data_;
/*! \brief Config of gbdt */
std::unique_ptr<BoostingConfig> gbdt_config_;
std::unique_ptr<Config> config_;
/*! \brief Tree learner, will use this class to learn trees */
std::unique_ptr<TreeLearner> tree_learner_;
/*! \brief Objective function */
......
......@@ -300,6 +300,10 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
for (size_t i = 0; i < pairs.size(); ++i) {
ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
}
if (config_ != nullptr) {
ss << "parameters:" << '\n';
ss << config_->ToString() << "\n";
}
return ss.str();
}
......
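With this change a saved model ends with a "parameters:" block built from Config::ToString(); judging from the generated SaveMembersToString() shown earlier, each entry is a "[name: value]" line. A small helper, an illustration rather than part of LightGBM, that reads such a block back:

def read_saved_parameters(model_text):
    params = {}
    in_params = False
    for line in model_text.splitlines():
        line = line.strip()
        if line == "parameters:":
            in_params = True
        elif in_params and line.startswith("[") and line.endswith("]") and ":" in line:
            name, value = line[1:-1].split(":", 1)
            params[name.strip()] = value.strip()
    return params

tail = "parameters:\n[num_threads: 8]\n[learning_rate: 0.1]\n"
print(read_saved_parameters(tail))  # {'num_threads': '8', 'learning_rate': '0.1'}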
......@@ -39,7 +39,7 @@ public:
#endif
}
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
ResetGoss();
......@@ -51,15 +51,15 @@ public:
ResetGoss();
}
void ResetConfig(const BoostingConfig* config) override {
void ResetConfig(const Config* config) override {
GBDT::ResetConfig(config);
ResetGoss();
}
void ResetGoss() {
CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
CHECK(config_->top_rate + config_->other_rate <= 1.0f);
CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) {
Log::Fatal("Cannot use bagging in GOSS");
}
Log::Info("Using GOSS");
......@@ -74,8 +74,8 @@ public:
right_write_pos_buf_.resize(num_threads_);
is_use_subset_ = false;
if (gbdt_config_->top_rate + gbdt_config_->other_rate <= 0.5) {
auto bag_data_cnt = static_cast<data_size_t>((gbdt_config_->top_rate + gbdt_config_->other_rate) * num_data_);
if (config_->top_rate + config_->other_rate <= 0.5) {
auto bag_data_cnt = static_cast<data_size_t>((config_->top_rate + config_->other_rate) * num_data_);
bag_data_cnt = std::max(1, bag_data_cnt);
tmp_subset_.reset(new Dataset(bag_data_cnt));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
......@@ -93,8 +93,8 @@ public:
tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
}
}
data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate);
data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate);
data_size_t top_k = static_cast<data_size_t>(cnt * config_->top_rate);
data_size_t other_k = static_cast<data_size_t>(cnt * config_->other_rate);
top_k = std::max(1, top_k);
ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k - 1);
score_t threshold = tmp_gradients[top_k - 1];
......@@ -135,7 +135,7 @@ public:
void Bagging(int iter) override {
bag_data_cnt_ = num_data_;
// not subsample for first iterations
if (iter < static_cast<int>(1.0f / gbdt_config_->learning_rate)) { return; }
if (iter < static_cast<int>(1.0f / config_->learning_rate)) { return; }
const data_size_t min_inner_size = 100;
data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
......@@ -150,7 +150,7 @@ public:
if (cur_start > num_data_) { continue; }
data_size_t cur_cnt = inner_size;
if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
Random cur_rand(config_->bagging_seed + iter * num_threads_ + i);
data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt,
tmp_indices_.data() + cur_start, tmp_indice_right_.data() + cur_start);
offsets_buf_[i] = cur_start;
......
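The GOSS constants above (config_->top_rate, config_->other_rate) drive a two-part sample: keep the instances with the largest |gradient * hessian|, then draw a uniform sample from the remainder. A simplified sketch; the reweighting that LightGBM applies to the sampled small-gradient instances is omitted here:

import random

def goss_select(grad_hess_abs, top_rate=0.2, other_rate=0.1, seed=0):
    # top_k and other_k mirror the counts computed above.
    n = len(grad_hess_abs)
    top_k = max(1, int(n * top_rate))
    other_k = int(n * other_rate)
    order = sorted(range(n), key=lambda i: grad_hess_abs[i], reverse=True)
    top = order[:top_k]
    rest = random.Random(seed).sample(order[top_k:], min(other_k, n - top_k))
    return top, rest

top, rest = goss_select([0.9, 0.1, 0.5, 0.05, 0.7, 0.2, 0.3, 0.4, 0.6, 0.8])
print(top)   # [0, 9] -- the two largest-gradient instances
print(rest)  # one uniformly sampled index from the rest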
......@@ -24,10 +24,10 @@ public:
~RF() {}
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
CHECK(config->tree_config.feature_fraction < 1.0f && config->tree_config.feature_fraction > 0.0f);
CHECK(config->feature_fraction < 1.0f && config->feature_fraction > 0.0f);
GBDT::Init(config, train_data, objective_function, training_metrics);
if (num_init_iteration_ > 0) {
......@@ -50,9 +50,9 @@ public:
}
}
void ResetConfig(const BoostingConfig* config) override {
void ResetConfig(const Config* config) override {
CHECK(config->bagging_freq > 0 && config->bagging_fraction < 1.0f && config->bagging_fraction > 0.0f);
CHECK(config->tree_config.feature_fraction < 1.0f && config->tree_config.feature_fraction > 0.0f);
CHECK(config->feature_fraction < 1.0f && config->feature_fraction > 0.0f);
GBDT::ResetConfig(config);
// not shrinkage rate for the RF
shrinkage_rate_ = 1.0f;
......