"python-package/vscode:/vscode.git/clone" did not exist on "4f77bd2860913a9a210c72a367fc88b4a9c13b3d"
Commit 99c6b539 authored by Guolin Ke's avatar Guolin Ke
Browse files

robust tree model loading.

parent 6757a6aa
......@@ -376,37 +376,84 @@ Tree::Tree(const std::string& str) {
}
}
if (key_vals.count("num_leaves") <= 0) {
Log::Fatal("Tree model string format error");
Log::Fatal("Tree model should contain num_leaves field.");
}
Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);
if (num_leaves_ <= 1) { return; }
if (key_vals.count("split_feature") <= 0
|| key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0
|| key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0
|| key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0
|| key_vals.count("internal_value") <= 0 || key_vals.count("internal_count") <= 0
|| key_vals.count("leaf_count") <= 0 || key_vals.count("shrinkage") <= 0
|| key_vals.count("decision_type") <= 0
) {
Log::Fatal("Tree model string format error");
if (key_vals.count("left_child")) {
left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain left_child field");
}
left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_leaves_ - 1);
right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_leaves_ - 1);
split_feature_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_leaves_ - 1);
threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_leaves_ - 1);
decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_leaves_ - 1);
internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_leaves_ - 1);
leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
leaf_parent_ = Common::StringToArray<int>(key_vals["leaf_parent"], ' ', num_leaves_);
leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
if (key_vals.count("right_child")) {
right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain right_child field");
}
if (key_vals.count("split_feature")) {
split_feature_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain split_feature field");
}
if (key_vals.count("threshold")) {
threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain threshold field");
}
if (key_vals.count("leaf_value")) {
leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
} else {
Log::Fatal("Tree model string format error, should contain leaf_value field");
}
if (key_vals.count("split_gain")) {
split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
} else {
split_gain_.resize(num_leaves_ - 1);
}
if (key_vals.count("internal_count")) {
internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_leaves_ - 1);
} else {
internal_count_.resize(num_leaves_ - 1);
}
if (key_vals.count("internal_value")) {
internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_leaves_ - 1);
} else {
internal_value_.resize(num_leaves_ - 1);
}
if (key_vals.count("leaf_count")) {
leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
} else {
leaf_count_.resize(num_leaves_);
}
if (key_vals.count("leaf_parent")) {
leaf_parent_ = Common::StringToArray<int>(key_vals["leaf_parent"], ' ', num_leaves_);
} else {
leaf_parent_.resize(num_leaves_);
}
if (key_vals.count("decision_type")) {
decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
} else {
decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
}
if (key_vals.count("shrinkage")) {
Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
} else {
shrinkage_ = 1.0f;
}
}
} // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment