Commit 8a5ec366 (unverified), authored by Guolin Ke, committed by GitHub
Browse files

Speed up saving and loading model (#1083)

* remove protobuf

* add version number

* remove pmml script

* use float for split gain

* fix warnings

* refine the read model logic of gbdt

* fix compile error

* improve decode speed

* fix some bugs

* fix double accuracy problem

* fix bug

* multi-thread save model

* speed up save model to string

* parallel save/load model

* fix some warnings.

* fix warnings.

* fix a bug

* remove debug output

* fix doc

* fix max_bin warning in tests.

* fix max_bin warning

* fix pylint

* clean code for stringToArray

* clean code for TToString

* remove max_bin

* replace "class" with typename
parent 8d016c12
......@@ -29,7 +29,7 @@ namespace LightGBM {
class Booster {
public:
explicit Booster(const char* filename) {
boosting_.reset(Boosting::CreateBoosting("gbdt", "text", filename));
boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
}
Booster(const Dataset* train_data,
......@@ -46,7 +46,7 @@ public:
please use continued train with input score");
}
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, "text", nullptr));
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
train_data_ = train_data;
CreateObjectiveAndMetrics();
......@@ -240,7 +240,8 @@ public:
}
void LoadModelFromString(const char* model_str) {
boosting_->LoadModelFromString(model_str);
size_t len = std::strlen(model_str);
boosting_->LoadModelFromString(model_str, len);
}
std::string SaveModelToString(int num_iteration) {
......
......@@ -269,7 +269,6 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetString(params, "input_model", &input_model);
GetString(params, "convert_model", &convert_model);
GetString(params, "output_result", &output_result);
GetString(params, "model_format", &model_format);
std::string tmp_str = "";
if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
......
......@@ -49,7 +49,7 @@ Tree::~Tree() {
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type, bool default_left) {
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
......@@ -70,7 +70,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type) {
data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
......@@ -207,49 +207,49 @@ void Tree::AddPredictionToScore(const Dataset* data,
std::string Tree::ToString() const {
std::stringstream str_buf;
str_buf << "num_leaves=" << num_leaves_ << std::endl;
str_buf << "num_cat=" << num_cat_ << std::endl;
str_buf << "num_leaves=" << num_leaves_ << '\n';
str_buf << "num_cat=" << num_cat_ << '\n';
str_buf << "split_feature="
<< Common::ArrayToString<int>(split_feature_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
str_buf << "split_gain="
<< Common::ArrayToString<double>(split_gain_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
str_buf << "threshold="
<< Common::ArrayToString<double>(threshold_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
str_buf << "decision_type="
<< Common::ArrayToString<int>(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
str_buf << "left_child="
<< Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
str_buf << "right_child="
<< Common::ArrayToString<int>(right_child_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
str_buf << "leaf_value="
<< Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
<< Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
str_buf << "leaf_count="
<< Common::ArrayToString<data_size_t>(leaf_count_, num_leaves_, ' ') << std::endl;
<< Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
str_buf << "internal_value="
<< Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
str_buf << "internal_count="
<< Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl;
<< Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
if (num_cat_ > 0) {
str_buf << "cat_boundaries="
<< Common::ArrayToString<int>(cat_boundaries_, num_cat_ + 1, ' ') << std::endl;
<< Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
str_buf << "cat_threshold="
<< Common::ArrayToString<uint32_t>(cat_threshold_, cat_threshold_.size(), ' ') << std::endl;
<< Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
}
str_buf << "shrinkage=" << shrinkage_ << std::endl;
str_buf << std::endl;
str_buf << "shrinkage=" << shrinkage_ << '\n';
str_buf << '\n';
return str_buf.str();
}
std::string Tree::ToJSON() const {
std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl;
str_buf << "\"num_cat\":" << num_cat_ << "," << std::endl;
str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
if (num_leaves_ == 1) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << std::endl;
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
} else {
str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
}
return str_buf.str();
......@@ -260,10 +260,10 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
if (index >= 0) {
// non-leaf
str_buf << "{" << std::endl;
str_buf << "\"split_index\":" << index << "," << std::endl;
str_buf << "\"split_feature\":" << split_feature_[index] << "," << std::endl;
str_buf << "\"split_gain\":" << split_gain_[index] << "," << std::endl;
str_buf << "{" << '\n';
str_buf << "\"split_index\":" << index << "," << '\n';
str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
int cat_idx = static_cast<int>(threshold_[index]);
std::vector<int> cats;
......@@ -276,37 +276,37 @@ std::string Tree::NodeToJSON(int index) const {
}
}
}
str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << std::endl;
str_buf << "\"decision_type\":\"==\"," << std::endl;
str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
str_buf << "\"decision_type\":\"==\"," << '\n';
} else {
str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << std::endl;
str_buf << "\"decision_type\":\"<=\"," << std::endl;
str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
str_buf << "\"decision_type\":\"<=\"," << '\n';
}
if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
str_buf << "\"default_left\":true," << std::endl;
str_buf << "\"default_left\":true," << '\n';
} else {
str_buf << "\"default_left\":false," << std::endl;
str_buf << "\"default_left\":false," << '\n';
}
uint8_t missing_type = GetMissingType(decision_type_[index]);
if (missing_type == 0) {
str_buf << "\"missing_type\":\"None\"," << std::endl;
str_buf << "\"missing_type\":\"None\"," << '\n';
} else if (missing_type == 1) {
str_buf << "\"missing_type\":\"Zero\"," << std::endl;
str_buf << "\"missing_type\":\"Zero\"," << '\n';
} else {
str_buf << "\"missing_type\":\"NaN\"," << std::endl;
str_buf << "\"missing_type\":\"NaN\"," << '\n';
}
str_buf << "\"internal_value\":" << internal_value_[index] << "," << std::endl;
str_buf << "\"internal_count\":" << internal_count_[index] << "," << std::endl;
str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << std::endl;
str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << std::endl;
str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
str_buf << "}";
} else {
// leaf
index = ~index;
str_buf << "{" << std::endl;
str_buf << "\"leaf_index\":" << index << "," << std::endl;
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl;
str_buf << "\"leaf_count\":" << leaf_count_[index] << std::endl;
str_buf << "{" << '\n';
str_buf << "\"leaf_index\":" << index << "," << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
str_buf << "}";
}
......@@ -376,7 +376,7 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
}
str_buf << NodeToIfElse(0, is_predict_leaf_index);
}
str_buf << " }" << std::endl;
str_buf << " }" << '\n';
//Predict func by Map to ifelse
str_buf << "double PredictTree" << index;
......@@ -404,7 +404,7 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
}
str_buf << NodeToIfElseByMap(0, is_predict_leaf_index);
}
str_buf << " }" << std::endl;
str_buf << " }" << '\n';
return str_buf.str();
}
......@@ -471,19 +471,26 @@ std::string Tree::NodeToIfElseByMap(int index, bool is_predict_leaf_index) const
return str_buf.str();
}
Tree::Tree(const std::string& str) {
std::vector<std::string> lines = Common::SplitLines(str.c_str());
Tree::Tree(const char* str, size_t* used_len) {
auto p = str;
std::unordered_map<std::string, std::string> key_vals;
for (const std::string& line : lines) {
std::vector<std::string> tmp_strs = Common::Split(line.c_str(), '=');
if (tmp_strs.size() == 2) {
std::string key = Common::Trim(tmp_strs[0]);
std::string val = Common::Trim(tmp_strs[1]);
if (key.size() > 0 && val.size() > 0) {
key_vals[key] = val;
}
}
}
const int max_num_line = 15;
int read_line = 0;
while (read_line < max_num_line) {
if (*p == '\r' || *p == '\n') break;
auto start = p;
while (*p != '=') ++p;
std::string key(start, p - start);
++p;
start = p;
while (*p != '\r' && *p != '\n') ++p;
key_vals[key] = std::string(start, p - start);
++read_line;
if (*p == '\r') ++p;
if (*p == '\n') ++p;
}
*used_len = p - str;
if (key_vals.count("num_leaves") <= 0) {
Log::Fatal("Tree model should contain num_leaves field.");
}
......@@ -497,7 +504,7 @@ Tree::Tree(const std::string& str) {
Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);
if (key_vals.count("leaf_value")) {
leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
} else {
Log::Fatal("Tree model string format error, should contain leaf_value field");
}
......@@ -505,68 +512,68 @@ Tree::Tree(const std::string& str) {
if (num_leaves_ <= 1) { return; }
if (key_vals.count("left_child")) {
left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_leaves_ - 1);
left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain left_child field");
}
if (key_vals.count("right_child")) {
right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_leaves_ - 1);
right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain right_child field");
}
if (key_vals.count("split_feature")) {
split_feature_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_leaves_ - 1);
split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain split_feature field");
}
if (key_vals.count("threshold")) {
threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_leaves_ - 1);
threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
} else {
Log::Fatal("Tree model string format error, should contain threshold field");
}
if (key_vals.count("split_gain")) {
split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
} else {
split_gain_.resize(num_leaves_ - 1);
}
if (key_vals.count("internal_count")) {
internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_leaves_ - 1);
internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
} else {
internal_count_.resize(num_leaves_ - 1);
}
if (key_vals.count("internal_value")) {
internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_leaves_ - 1);
internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
} else {
internal_value_.resize(num_leaves_ - 1);
}
if (key_vals.count("leaf_count")) {
leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
} else {
leaf_count_.resize(num_leaves_);
}
if (key_vals.count("decision_type")) {
decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
} else {
decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
}
if (num_cat_ > 0) {
if (key_vals.count("cat_boundaries")) {
cat_boundaries_ = Common::StringToArray<int>(key_vals["cat_boundaries"], ' ', num_cat_ + 1);
cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
} else {
Log::Fatal("Tree model should contain cat_boundaries field.");
}
if (key_vals.count("cat_threshold")) {
cat_threshold_ = Common::StringToArray<uint32_t>(key_vals["cat_threshold"], ' ', cat_boundaries_.back());
cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
} else {
Log::Fatal("Tree model should contain cat_threshold field.");
}
......
#include "../boosting/gbdt.h"
#include <LightGBM/tree.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <iostream>
#include <fstream>
namespace LightGBM {
/*!
 * \brief Serialize the model into a protobuf binary file.
 * \param num_iteration When positive, at most
 *        num_iteration * num_tree_per_iteration_ trees are written;
 *        otherwise every tree in models_ is written.
 * \param filename Path of the file to write.
 */
void GBDT::SaveModelToProto(int num_iteration, const char* filename) const {
  LightGBM::Model model;
  model.set_name(SubModelName());
  model.set_num_class(num_class_);
  model.set_num_tree_per_iteration(num_tree_per_iteration_);
  model.set_label_index(label_idx_);
  model.set_max_feature_idx(max_feature_idx_);
  if (objective_function_ != nullptr) {
    model.set_objective(objective_function_->ToString());
  }
  model.set_average_output(average_output_);
  // Iterate by const reference so each name/info string is not copied first.
  for (const auto& feature_name : feature_names_) {
    model.add_feature_names(feature_name);
  }
  for (const auto& feature_info : feature_infos_) {
    model.add_feature_infos(feature_info);
  }

  int num_used_model = static_cast<int>(models_.size());
  if (num_iteration > 0) {
    num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
  }
  for (int i = 0; i < num_used_model; ++i) {
    models_[i]->ToProto(*model.add_trees());
  }

  std::filebuf fb;
  // Report an unopenable path directly instead of letting it surface later
  // as a serialization failure.
  if (fb.open(filename, std::ios::out | std::ios::binary) == nullptr) {
    Log::Fatal("Cannot open file: %s.", filename);
  }
  std::ostream os(&fb);
  if (!model.SerializeToOstream(&os)) {
    Log::Fatal("Cannot serialize model to binary file.");
  }
  fb.close();
}
/*!
 * \brief Replace the current model with one read from a protobuf binary file.
 * \param filename Path of the file to read.
 * \return true once the model has been loaded.
 */
bool GBDT::LoadModelFromProto(const char* filename) {
  models_.clear();
  LightGBM::Model model;
  std::filebuf fb;
  if (fb.open(filename, std::ios::in | std::ios::binary)) {
    std::istream is(&fb);
    if (!model.ParseFromIstream(&is)) {
      Log::Fatal("Cannot parse model from binary file.");
    }
    fb.close();
  } else {
    Log::Fatal("Cannot open file: %s.", filename);
  }

  num_class_ = model.num_class();
  num_tree_per_iteration_ = model.num_tree_per_iteration();
  label_idx_ = model.label_index();
  max_feature_idx_ = model.max_feature_idx();
  average_output_ = model.average_output();

  // The value is used as a divisor below; reject a file that would make it zero.
  if (num_tree_per_iteration_ <= 0) {
    Log::Fatal("Model file should contain a positive num_tree_per_iteration.");
  }

  // Drop any previously loaded metadata, mirroring models_.clear() above, so
  // that loading a second time does not append duplicate entries.
  feature_names_.clear();
  feature_names_.reserve(model.feature_names_size());
  for (const auto& feature_name : model.feature_names()) {
    feature_names_.push_back(feature_name);
  }
  feature_infos_.clear();
  feature_infos_.reserve(model.feature_infos_size());
  for (const auto& feature_info : model.feature_infos()) {
    feature_infos_.push_back(feature_info);
  }

  loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(model.objective()));
  objective_function_ = loaded_objective_.get();

  models_.reserve(model.trees_size());
  // Bind by const reference: copying a whole tree message per iteration is wasted work.
  for (const auto& tree : model.trees()) {
    models_.emplace_back(new Tree(tree));
  }
  // %d expects an int, so convert the size_t explicitly.
  Log::Info("Finished loading %d models", static_cast<int>(models_.size()));
  num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  num_init_iteration_ = num_iteration_for_pred_;
  iter_ = 0;
  return true;
}
/*!
 * \brief Copy this tree's fields into a protobuf tree message.
 * \param model_tree Message that receives the scalar fields and the
 *        per-node, per-leaf and categorical arrays.
 */
void Tree::ToProto(LightGBM::Model_Tree& model_tree) const {
  const int num_internal_nodes = num_leaves_ - 1;

  // Scalar fields.
  model_tree.set_num_leaves(num_leaves_);
  model_tree.set_num_cat(num_cat_);
  model_tree.set_shrinkage(shrinkage_);

  // One entry per internal (split) node.
  for (int node = 0; node < num_internal_nodes; ++node) {
    model_tree.add_split_feature(split_feature_[node]);
    model_tree.add_split_gain(split_gain_[node]);
    model_tree.add_threshold(threshold_[node]);
    model_tree.add_decision_type(decision_type_[node]);
    model_tree.add_left_child(left_child_[node]);
    model_tree.add_right_child(right_child_[node]);
    model_tree.add_internal_value(internal_value_[node]);
    model_tree.add_internal_count(internal_count_[node]);
  }

  // One entry per leaf.
  for (int leaf = 0; leaf < num_leaves_; ++leaf) {
    model_tree.add_leaf_value(leaf_value_[leaf]);
    model_tree.add_leaf_count(leaf_count_[leaf]);
  }

  // Categorical arrays are only written when the tree has categorical splits.
  if (num_cat_ <= 0) {
    return;
  }
  for (int boundary = 0; boundary <= num_cat_; ++boundary) {
    model_tree.add_cat_boundaries(cat_boundaries_[boundary]);
  }
  for (const auto threshold_word : cat_threshold_) {
    model_tree.add_cat_threshold(threshold_word);
  }
}
/*!
 * \brief Construct a tree from its protobuf representation.
 * \param model_tree Protobuf message holding the serialized tree.
 */
Tree::Tree(const LightGBM::Model_Tree& model_tree) {
  num_leaves_ = model_tree.num_leaves();
  num_cat_ = model_tree.num_cat();
  shrinkage_ = model_tree.shrinkage();

  // Leaf values are loaded before the early return below: a single-leaf tree
  // still has one leaf value (ToJSON reads leaf_value_[0] in that case), and
  // the text-based constructor likewise loads num_cat and leaf_value before
  // returning for num_leaves_ <= 1.
  leaf_value_.reserve(model_tree.leaf_value_size());
  for (auto leaf_value : model_tree.leaf_value()) {
    leaf_value_.push_back(leaf_value);
  }

  // Nothing else is loaded for a tree with no internal nodes.
  if (num_leaves_ <= 1) { return; }

  left_child_.reserve(model_tree.left_child_size());
  for (auto left_child : model_tree.left_child()) {
    left_child_.push_back(left_child);
  }

  right_child_.reserve(model_tree.right_child_size());
  for (auto right_child : model_tree.right_child()) {
    right_child_.push_back(right_child);
  }

  split_feature_.reserve(model_tree.split_feature_size());
  for (auto split_feature : model_tree.split_feature()) {
    split_feature_.push_back(split_feature);
  }

  threshold_.reserve(model_tree.threshold_size());
  for (auto threshold : model_tree.threshold()) {
    threshold_.push_back(threshold);
  }

  split_gain_.reserve(model_tree.split_gain_size());
  for (auto split_gain : model_tree.split_gain()) {
    split_gain_.push_back(split_gain);
  }

  internal_count_.reserve(model_tree.internal_count_size());
  for (auto internal_count : model_tree.internal_count()) {
    internal_count_.push_back(internal_count);
  }

  internal_value_.reserve(model_tree.internal_value_size());
  for (auto internal_value : model_tree.internal_value()) {
    internal_value_.push_back(internal_value);
  }

  leaf_count_.reserve(model_tree.leaf_count_size());
  for (auto leaf_count : model_tree.leaf_count()) {
    leaf_count_.push_back(leaf_count);
  }

  decision_type_.reserve(model_tree.decision_type_size());
  for (auto decision_type : model_tree.decision_type()) {
    decision_type_.push_back(decision_type);
  }

  // Categorical arrays are only present when the tree has categorical splits.
  if (num_cat_ > 0) {
    cat_boundaries_.reserve(model_tree.cat_boundaries_size());
    for (auto cat_boundaries : model_tree.cat_boundaries()) {
      cat_boundaries_.push_back(cat_boundaries);
    }

    cat_threshold_.reserve(model_tree.cat_threshold_size());
    for (auto cat_threshold : model_tree.cat_threshold()) {
      cat_threshold_.push_back(cat_threshold);
    }
  }
}
} // namespace LightGBM
......@@ -533,7 +533,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.gain),
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left);
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
......@@ -556,7 +556,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.gain),
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()), best_split_info.default_left, *right_leaf);
......
......@@ -15,7 +15,7 @@ class TestBasic(unittest.TestCase):
def test(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=2)
train_data = lgb.Dataset(X_train, max_bin=255, label=y_train)
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = train_data.create_valid(X_test, label=y_test)
params = {
......@@ -24,7 +24,8 @@ class TestBasic(unittest.TestCase):
"min_data": 10,
"num_leaves": 15,
"verbose": -1,
"num_threads": 1
"num_threads": 1,
"max_bin": 255
}
bst = lgb.Booster(params, train_data)
bst.add_valid(valid_data, "valid_1")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment