Commit 2e962c77 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix tests.

parent e179c7c6
...@@ -630,6 +630,7 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -630,6 +630,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf << "\"tree_info\":["; str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) { if (num_iteration > 0) {
num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_class_, num_used_model); num_used_model = std::min(num_iteration * num_class_, num_used_model);
} }
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model; ++i) {
...@@ -648,7 +649,7 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -648,7 +649,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str(); return str_buf.str();
} }
std::string GBDT::SaveModelToString(int num_iterations) const { std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss; std::stringstream ss;
// output model type // output model type
...@@ -676,8 +677,9 @@ std::string GBDT::SaveModelToString(int num_iterations) const { ...@@ -676,8 +677,9 @@ std::string GBDT::SaveModelToString(int num_iterations) const {
ss << std::endl; ss << std::endl;
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
if (num_iterations > 0) { if (num_iteration > 0) {
num_used_model = std::min(num_iterations * num_class_, num_used_model); num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_class_, num_used_model);
} }
// output tree models // output tree models
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model; ++i) {
......
...@@ -89,7 +89,7 @@ public: ...@@ -89,7 +89,7 @@ public:
*/ */
void RollbackOneIter() override; void RollbackOneIter() override;
int GetCurrentIteration() const override { return iter_ + num_init_iteration_; } int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_class_; }
bool EvalAndCheckEarlyStopping() override; bool EvalAndCheckEarlyStopping() override;
......
...@@ -32,7 +32,7 @@ class template(object): ...@@ -32,7 +32,7 @@ class template(object):
@staticmethod @staticmethod
def test_template(params={'objective': 'regression', 'metric': 'l2'}, def test_template(params={'objective': 'regression', 'metric': 'l2'},
X_y=load_boston(True), feval=mean_squared_error, X_y=load_boston(True), feval=mean_squared_error,
num_round=150, init_model=None, custom_eval=None, num_round=200, init_model=None, custom_eval=None,
early_stopping_rounds=10, early_stopping_rounds=10,
return_data=False, return_model=False): return_data=False, return_model=False):
params['verbose'], params['seed'] = -1, 42 params['verbose'], params['seed'] = -1, 42
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment