Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
2e962c77
Commit
2e962c77
authored
Mar 23, 2017
by
Guolin Ke
Browse files
fix tests.
parent
e179c7c6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
5 deletions
+7
-5
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+5
-3
src/boosting/gbdt.h
src/boosting/gbdt.h
+1
-1
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+1
-1
No files found.
src/boosting/gbdt.cpp
View file @
2e962c77
...
...
@@ -630,6 +630,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf
<<
"
\"
tree_info
\"
:["
;
int
num_used_model
=
static_cast
<
int
>
(
models_
.
size
());
if
(
num_iteration
>
0
)
{
num_iteration
+=
boost_from_average_
?
1
:
0
;
num_used_model
=
std
::
min
(
num_iteration
*
num_class_
,
num_used_model
);
}
for
(
int
i
=
0
;
i
<
num_used_model
;
++
i
)
{
...
...
@@ -648,7 +649,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
return
str_buf
.
str
();
}
std
::
string
GBDT
::
SaveModelToString
(
int
num_iteration
s
)
const
{
std
::
string
GBDT
::
SaveModelToString
(
int
num_iteration
)
const
{
std
::
stringstream
ss
;
// output model type
...
...
@@ -676,8 +677,9 @@ std::string GBDT::SaveModelToString(int num_iterations) const {
ss
<<
std
::
endl
;
int
num_used_model
=
static_cast
<
int
>
(
models_
.
size
());
if
(
num_iterations
>
0
)
{
num_used_model
=
std
::
min
(
num_iterations
*
num_class_
,
num_used_model
);
if
(
num_iteration
>
0
)
{
num_iteration
+=
boost_from_average_
?
1
:
0
;
num_used_model
=
std
::
min
(
num_iteration
*
num_class_
,
num_used_model
);
}
// output tree models
for
(
int
i
=
0
;
i
<
num_used_model
;
++
i
)
{
...
...
src/boosting/gbdt.h
View file @
2e962c77
...
...
@@ -89,7 +89,7 @@ public:
*/
void
RollbackOneIter
()
override
;
int
GetCurrentIteration
()
const
override
{
return
iter_
+
num_init_iteration
_
;
}
int
GetCurrentIteration
()
const
override
{
return
static_cast
<
int
>
(
models_
.
size
())
/
num_class
_
;
}
bool
EvalAndCheckEarlyStopping
()
override
;
...
...
tests/python_package_test/test_engine.py
View file @
2e962c77
...
...
@@ -32,7 +32,7 @@ class template(object):
@
staticmethod
def
test_template
(
params
=
{
'objective'
:
'regression'
,
'metric'
:
'l2'
},
X_y
=
load_boston
(
True
),
feval
=
mean_squared_error
,
num_round
=
15
0
,
init_model
=
None
,
custom_eval
=
None
,
num_round
=
20
0
,
init_model
=
None
,
custom_eval
=
None
,
early_stopping_rounds
=
10
,
return_data
=
False
,
return_model
=
False
):
params
[
'verbose'
],
params
[
'seed'
]
=
-
1
,
42
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment