Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
99c6b539
Commit
99c6b539
authored
Mar 24, 2017
by
Guolin Ke
Browse files
robust tree model loading.
parent
6757a6aa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
23 deletions
+70
-23
src/io/tree.cpp
src/io/tree.cpp
+70
-23
No files found.
src/io/tree.cpp
View file @
99c6b539
...
...
@@ -376,37 +376,84 @@ Tree::Tree(const std::string& str) {
}
}
if
(
key_vals
.
count
(
"num_leaves"
)
<=
0
)
{
Log
::
Fatal
(
"Tree model s
tring format error
"
);
Log
::
Fatal
(
"Tree model s
hould contain num_leaves field.
"
);
}
Common
::
Atoi
(
key_vals
[
"num_leaves"
].
c_str
(),
&
num_leaves_
);
if
(
num_leaves_
<=
1
)
{
return
;
}
if
(
key_vals
.
count
(
"split_feature"
)
<=
0
||
key_vals
.
count
(
"split_gain"
)
<=
0
||
key_vals
.
count
(
"threshold"
)
<=
0
||
key_vals
.
count
(
"left_child"
)
<=
0
||
key_vals
.
count
(
"right_child"
)
<=
0
||
key_vals
.
count
(
"leaf_parent"
)
<=
0
||
key_vals
.
count
(
"leaf_value"
)
<=
0
||
key_vals
.
count
(
"internal_value"
)
<=
0
||
key_vals
.
count
(
"internal_count"
)
<=
0
||
key_vals
.
count
(
"leaf_count"
)
<=
0
||
key_vals
.
count
(
"shrinkage"
)
<=
0
||
key_vals
.
count
(
"decision_type"
)
<=
0
)
{
Log
::
Fatal
(
"Tree model string format error"
);
if
(
key_vals
.
count
(
"left_child"
))
{
left_child_
=
Common
::
StringToArray
<
int
>
(
key_vals
[
"left_child"
],
' '
,
num_leaves_
-
1
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain left_child field"
);
}
left_child_
=
Common
::
StringToArray
<
int
>
(
key_vals
[
"left_child"
],
' '
,
num_leaves_
-
1
);
if
(
key_vals
.
count
(
"right_child"
))
{
right_child_
=
Common
::
StringToArray
<
int
>
(
key_vals
[
"right_child"
],
' '
,
num_leaves_
-
1
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain right_child field"
);
}
if
(
key_vals
.
count
(
"split_feature"
))
{
split_feature_
=
Common
::
StringToArray
<
int
>
(
key_vals
[
"split_feature"
],
' '
,
num_leaves_
-
1
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain split_feature field"
);
}
if
(
key_vals
.
count
(
"threshold"
))
{
threshold_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"threshold"
],
' '
,
num_leaves_
-
1
);
decision_type_
=
Common
::
StringToArray
<
int8_t
>
(
key_vals
[
"decision_type"
],
' '
,
num_leaves_
-
1
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain threshold field"
);
}
if
(
key_vals
.
count
(
"leaf_value"
))
{
leaf_value_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"leaf_value"
],
' '
,
num_leaves_
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain leaf_value field"
);
}
if
(
key_vals
.
count
(
"split_gain"
))
{
split_gain_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"split_gain"
],
' '
,
num_leaves_
-
1
);
}
else
{
split_gain_
.
resize
(
num_leaves_
-
1
);
}
if
(
key_vals
.
count
(
"internal_count"
))
{
internal_count_
=
Common
::
StringToArray
<
data_size_t
>
(
key_vals
[
"internal_count"
],
' '
,
num_leaves_
-
1
);
}
else
{
internal_count_
.
resize
(
num_leaves_
-
1
);
}
if
(
key_vals
.
count
(
"internal_value"
))
{
internal_value_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"internal_value"
],
' '
,
num_leaves_
-
1
);
}
else
{
internal_value_
.
resize
(
num_leaves_
-
1
);
}
if
(
key_vals
.
count
(
"leaf_count"
))
{
leaf_count_
=
Common
::
StringToArray
<
data_size_t
>
(
key_vals
[
"leaf_count"
],
' '
,
num_leaves_
);
}
else
{
leaf_count_
.
resize
(
num_leaves_
);
}
if
(
key_vals
.
count
(
"leaf_parent"
))
{
leaf_parent_
=
Common
::
StringToArray
<
int
>
(
key_vals
[
"leaf_parent"
],
' '
,
num_leaves_
);
leaf_value_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"leaf_value"
],
' '
,
num_leaves_
);
}
else
{
leaf_parent_
.
resize
(
num_leaves_
);
}
if
(
key_vals
.
count
(
"decision_type"
))
{
decision_type_
=
Common
::
StringToArray
<
int8_t
>
(
key_vals
[
"decision_type"
],
' '
,
num_leaves_
-
1
);
}
else
{
decision_type_
=
std
::
vector
<
int8_t
>
(
num_leaves_
-
1
,
0
);
}
if
(
key_vals
.
count
(
"shrinkage"
))
{
Common
::
Atof
(
key_vals
[
"shrinkage"
].
c_str
(),
&
shrinkage_
);
}
else
{
shrinkage_
=
1.0
f
;
}
}
}
// namespace LightGBM
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment