"include/LightGBM/vscode:/vscode.git/clone" did not exist on "3a665f62e1c59d5ef80276fe533a0656ee9f8bf3"
Commit dd316895 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix name of boosting type

parent 76c44d78
...@@ -182,9 +182,9 @@ public: ...@@ -182,9 +182,9 @@ public:
virtual void SetNumIterationForPred(int num_iteration) = 0; virtual void SetNumIterationForPred(int num_iteration) = 0;
/*! /*!
* \brief Get Type name of this boosting object * \brief Name of submodel
*/ */
virtual const char* Name() const = 0; virtual const char* SubModelName() const = 0;
Boosting() = default; Boosting() = default;
/*! \brief Disable copy */ /*! \brief Disable copy */
...@@ -201,7 +201,7 @@ public: ...@@ -201,7 +201,7 @@ public:
* \param filename name of model file, if existing will continue to train from this model * \param filename name of model file, if existing will continue to train from this model
* \return The boosting object * \return The boosting object
*/ */
static Boosting* CreateBoosting(BoostingType type, const char* filename); static Boosting* CreateBoosting(const std::string& type, const char* filename);
/*! /*!
* \brief Create boosting object from model file * \brief Create boosting object from model file
......
...@@ -76,12 +76,6 @@ public: ...@@ -76,12 +76,6 @@ public:
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters); static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);
}; };
/*! \brief Types of boosting */
enum BoostingType {
kGBDT, kDART, kUnknow
};
/*! \brief Types of tasks */ /*! \brief Types of tasks */
enum TaskType { enum TaskType {
kTrain, kPredict kTrain, kPredict
...@@ -240,7 +234,7 @@ public: ...@@ -240,7 +234,7 @@ public:
bool is_parallel = false; bool is_parallel = false;
bool is_parallel_find_bin = false; bool is_parallel_find_bin = false;
IOConfig io_config; IOConfig io_config;
BoostingType boosting_type = BoostingType::kGBDT; std::string boosting_type = "gbdt";
BoostingConfig boosting_config; BoostingConfig boosting_config;
std::string objective_type = "regression"; std::string objective_type = "regression";
ObjectiveConfig objective_config; ObjectiveConfig objective_config;
......
...@@ -190,13 +190,14 @@ void Application::InitTrain() { ...@@ -190,13 +190,14 @@ void Application::InitTrain() {
Network::Init(config_.network_config); Network::Init(config_.network_config);
Log::Info("Finished initializing network"); Log::Info("Finished initializing network");
// sync global random seed for feature patition // sync global random seed for feature patition
if (config_.boosting_type == BoostingType::kGBDT || config_.boosting_type == BoostingType::kDART) {
config_.boosting_config.tree_config.feature_fraction_seed = config_.boosting_config.tree_config.feature_fraction_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed); GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed);
config_.boosting_config.tree_config.feature_fraction = config_.boosting_config.tree_config.feature_fraction =
GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction); GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction);
config_.boosting_config.drop_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.drop_seed);
} }
}
// create boosting // create boosting
boosting_.reset( boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type, Boosting::CreateBoosting(config_.boosting_type,
......
...@@ -4,15 +4,10 @@ ...@@ -4,15 +4,10 @@
namespace LightGBM { namespace LightGBM {
BoostingType GetBoostingTypeFromModelFile(const char* filename) { std::string GetBoostingTypeFromModelFile(const char* filename) {
TextReader<size_t> model_reader(filename, true); TextReader<size_t> model_reader(filename, true);
std::string type = model_reader.first_line(); std::string type = model_reader.first_line();
if (type == std::string("gbdt")) { return type;
return BoostingType::kGBDT;
} else if (type == std::string("dart")) {
return BoostingType::kDART;
}
return BoostingType::kUnknow;
} }
void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) { void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
...@@ -27,11 +22,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) { ...@@ -27,11 +22,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
} }
} }
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) { Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
if (filename == nullptr || filename[0] == '\0') { if (filename == nullptr || filename[0] == '\0') {
if (type == BoostingType::kGBDT) { if (type == std::string("gbdt")) {
return new GBDT(); return new GBDT();
} else if (type == BoostingType::kDART) { } else if (type == std::string("dart")) {
return new DART(); return new DART();
} else { } else {
return nullptr; return nullptr;
...@@ -39,15 +34,15 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) { ...@@ -39,15 +34,15 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
} else { } else {
std::unique_ptr<Boosting> ret; std::unique_ptr<Boosting> ret;
auto type_in_file = GetBoostingTypeFromModelFile(filename); auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == type) { if (type_in_file == std::string("tree")) {
if (type == BoostingType::kGBDT) { if (type == std::string("gbdt")) {
ret.reset(new GBDT()); ret.reset(new GBDT());
} else if (type == BoostingType::kDART) { } else if (type == std::string("dart")) {
ret.reset(new DART()); ret.reset(new DART());
} }
LoadFileToBoosting(ret.get(), filename); LoadFileToBoosting(ret.get(), filename);
} else { } else {
Log::Fatal("Boosting type in parameter is not the same as the type in the model file"); Log::Fatal("unknow submodel type in model file %s", filename);
} }
return ret.release(); return ret.release();
} }
...@@ -56,10 +51,10 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) { ...@@ -56,10 +51,10 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
Boosting* Boosting::CreateBoosting(const char* filename) { Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename); auto type = GetBoostingTypeFromModelFile(filename);
std::unique_ptr<Boosting> ret; std::unique_ptr<Boosting> ret;
if (type == BoostingType::kGBDT) { if (type == std::string("tree")) {
ret.reset(new GBDT()); ret.reset(new GBDT());
} else if (type == BoostingType::kDART) { } else {
ret.reset(new DART()); Log::Fatal("unknow submodel type in model file %s", filename);
} }
LoadFileToBoosting(ret.get(), filename); LoadFileToBoosting(ret.get(), filename);
return ret.release(); return ret.release();
......
...@@ -72,11 +72,6 @@ public: ...@@ -72,11 +72,6 @@ public:
return train_score_updater_->score(); return train_score_updater_->score();
} }
/*!
* \brief Get Type name of this boosting object
*/
const char* Name() const override { return "dart"; }
private: private:
/*! /*!
* \brief drop trees based on drop_rate * \brief drop trees based on drop_rate
......
...@@ -439,7 +439,7 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -439,7 +439,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "{"; str_buf << "{";
str_buf << "\"name\":\"" << Name() << "\"," << std::endl; str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
str_buf << "\"num_class\":" << num_class_ << "," << std::endl; str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
str_buf << "\"label_index\":" << label_idx_ << "," << std::endl; str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl; str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
...@@ -481,7 +481,7 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { ...@@ -481,7 +481,7 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
std::ofstream output_file; std::ofstream output_file;
output_file.open(filename); output_file.open(filename);
// output model type // output model type
output_file << Name() << std::endl; output_file << SubModelName() << std::endl;
// output number of class // output number of class
output_file << "num_class=" << num_class_ << std::endl; output_file << "num_class=" << num_class_ << std::endl;
// output label index // output label index
......
...@@ -212,7 +212,7 @@ public: ...@@ -212,7 +212,7 @@ public:
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
virtual const char* Name() const override { return "gbdt"; } virtual const char* SubModelName() const override { return "tree"; }
protected: protected:
/*! /*!
......
...@@ -76,9 +76,9 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s ...@@ -76,9 +76,9 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
if (GetString(params, "boosting_type", &value)) { if (GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) { if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = BoostingType::kGBDT; boosting_type = "gbdt";
} else if (value == std::string("dart")) { } else if (value == std::string("dart")) {
boosting_type = BoostingType::kDART; boosting_type = "dart";
} else { } else {
Log::Fatal("Unknown boosting type %s", value.c_str()); Log::Fatal("Unknown boosting type %s", value.c_str());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment