Commit dd316895 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix name of boosting type

parent 76c44d78
......@@ -182,9 +182,9 @@ public:
virtual void SetNumIterationForPred(int num_iteration) = 0;
/*!
* \brief Get Type name of this boosting object
* \brief Name of submodel
*/
virtual const char* Name() const = 0;
virtual const char* SubModelName() const = 0;
Boosting() = default;
/*! \brief Disable copy */
......@@ -201,7 +201,7 @@ public:
* \param filename name of model file, if existing will continue to train from this model
* \return The boosting object
*/
static Boosting* CreateBoosting(BoostingType type, const char* filename);
static Boosting* CreateBoosting(const std::string& type, const char* filename);
/*!
* \brief Create boosting object from model file
......
......@@ -76,12 +76,6 @@ public:
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);
};
/*! \brief Types of boosting */
enum BoostingType {
kGBDT, kDART, kUnknow
};
/*! \brief Types of tasks */
enum TaskType {
kTrain, kPredict
......@@ -240,7 +234,7 @@ public:
bool is_parallel = false;
bool is_parallel_find_bin = false;
IOConfig io_config;
BoostingType boosting_type = BoostingType::kGBDT;
std::string boosting_type = "gbdt";
BoostingConfig boosting_config;
std::string objective_type = "regression";
ObjectiveConfig objective_config;
......
......@@ -190,13 +190,14 @@ void Application::InitTrain() {
Network::Init(config_.network_config);
Log::Info("Finished initializing network");
// sync global random seed for feature patition
if (config_.boosting_type == BoostingType::kGBDT || config_.boosting_type == BoostingType::kDART) {
config_.boosting_config.tree_config.feature_fraction_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed);
config_.boosting_config.tree_config.feature_fraction =
GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction);
}
config_.boosting_config.tree_config.feature_fraction_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed);
config_.boosting_config.tree_config.feature_fraction =
GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction);
config_.boosting_config.drop_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.drop_seed);
}
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
......
......@@ -4,15 +4,10 @@
namespace LightGBM {
BoostingType GetBoostingTypeFromModelFile(const char* filename) {
std::string GetBoostingTypeFromModelFile(const char* filename) {
TextReader<size_t> model_reader(filename, true);
std::string type = model_reader.first_line();
if (type == std::string("gbdt")) {
return BoostingType::kGBDT;
} else if (type == std::string("dart")) {
return BoostingType::kDART;
}
return BoostingType::kUnknow;
return type;
}
void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
......@@ -27,11 +22,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
}
}
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
if (filename == nullptr || filename[0] == '\0') {
if (type == BoostingType::kGBDT) {
if (type == std::string("gbdt")) {
return new GBDT();
} else if (type == BoostingType::kDART) {
} else if (type == std::string("dart")) {
return new DART();
} else {
return nullptr;
......@@ -39,15 +34,15 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
} else {
std::unique_ptr<Boosting> ret;
auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == type) {
if (type == BoostingType::kGBDT) {
if (type_in_file == std::string("tree")) {
if (type == std::string("gbdt")) {
ret.reset(new GBDT());
} else if (type == BoostingType::kDART) {
} else if (type == std::string("dart")) {
ret.reset(new DART());
}
LoadFileToBoosting(ret.get(), filename);
} else {
Log::Fatal("Boosting type in parameter is not the same as the type in the model file");
Log::Fatal("unknow submodel type in model file %s", filename);
}
return ret.release();
}
......@@ -56,10 +51,10 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename);
std::unique_ptr<Boosting> ret;
if (type == BoostingType::kGBDT) {
if (type == std::string("tree")) {
ret.reset(new GBDT());
} else if (type == BoostingType::kDART) {
ret.reset(new DART());
} else {
Log::Fatal("unknow submodel type in model file %s", filename);
}
LoadFileToBoosting(ret.get(), filename);
return ret.release();
......
......@@ -72,11 +72,6 @@ public:
return train_score_updater_->score();
}
/*!
* \brief Get Type name of this boosting object
*/
const char* Name() const override { return "dart"; }
private:
/*!
* \brief drop trees based on drop_rate
......
......@@ -439,7 +439,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
std::stringstream str_buf;
str_buf << "{";
str_buf << "\"name\":\"" << Name() << "\"," << std::endl;
str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
......@@ -481,7 +481,7 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
std::ofstream output_file;
output_file.open(filename);
// output model type
output_file << Name() << std::endl;
output_file << SubModelName() << std::endl;
// output number of class
output_file << "num_class=" << num_class_ << std::endl;
// output label index
......
......@@ -212,7 +212,7 @@ public:
/*!
* \brief Get Type name of this boosting object
*/
virtual const char* Name() const override { return "gbdt"; }
virtual const char* SubModelName() const override { return "tree"; }
protected:
/*!
......
......@@ -76,9 +76,9 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
if (GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = BoostingType::kGBDT;
boosting_type = "gbdt";
} else if (value == std::string("dart")) {
boosting_type = BoostingType::kDART;
boosting_type = "dart";
} else {
Log::Fatal("Unknown boosting type %s", value.c_str());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment