Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
cc83cd67
Commit
cc83cd67
authored
Aug 20, 2017
by
Guolin Ke
Committed by
GitHub
Aug 20, 2017
Browse files
support constant tree (one-leaf tree) (#851)
parent
ecc8b8cd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
25 deletions
+20
-25
include/LightGBM/tree.h
include/LightGBM/tree.h
+7
-1
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+2
-3
src/boosting/rf.hpp
src/boosting/rf.hpp
+1
-2
src/io/tree.cpp
src/io/tree.cpp
+10
-19
No files found.
include/LightGBM/tree.h
View file @
cc83cd67
...
@@ -167,6 +167,12 @@ public:
...
@@ -167,6 +167,12 @@ public:
shrinkage_
*=
rate
;
shrinkage_
*=
rate
;
}
}
inline
void
AsConstantTree
(
double
val
)
{
num_leaves_
=
1
;
shrinkage_
=
1.0
f
;
leaf_value_
[
0
]
=
val
;
}
/*! \brief Serialize this object to string*/
/*! \brief Serialize this object to string*/
std
::
string
ToString
();
std
::
string
ToString
();
...
@@ -425,7 +431,7 @@ inline double Tree::Predict(const double* feature_values) const {
...
@@ -425,7 +431,7 @@ inline double Tree::Predict(const double* feature_values) const {
int
leaf
=
GetLeaf
(
feature_values
);
int
leaf
=
GetLeaf
(
feature_values
);
return
LeafOutput
(
leaf
);
return
LeafOutput
(
leaf
);
}
else
{
}
else
{
return
0.0
f
;
return
leaf_value_
[
0
]
;
}
}
}
}
...
...
src/boosting/gbdt.cpp
View file @
cc83cd67
...
@@ -473,7 +473,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -473,7 +473,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
auto
label
=
train_data_
->
metadata
().
label
();
auto
label
=
train_data_
->
metadata
().
label
();
double
init_score
=
ObtainAutomaticInitialScore
(
objective_function_
,
label
,
num_data_
);
double
init_score
=
ObtainAutomaticInitialScore
(
objective_function_
,
label
,
num_data_
);
std
::
unique_ptr
<
Tree
>
new_tree
(
new
Tree
(
2
));
std
::
unique_ptr
<
Tree
>
new_tree
(
new
Tree
(
2
));
new_tree
->
Split
(
0
,
0
,
0
,
0
,
0
,
init_score
,
init_score
,
0
,
0
,
-
1
,
MissingType
::
None
,
tru
e
);
new_tree
->
AsConstantTree
(
init_scor
e
);
train_score_updater_
->
AddScore
(
init_score
,
0
);
train_score_updater_
->
AddScore
(
init_score
,
0
);
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
init_score
,
0
);
score_updater
->
AddScore
(
init_score
,
0
);
...
@@ -553,8 +553,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -553,8 +553,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// only add default score one-time
// only add default score one-time
if
(
!
class_need_train_
[
cur_tree_id
]
&&
models_
.
size
()
<
static_cast
<
size_t
>
(
num_tree_per_iteration_
))
{
if
(
!
class_need_train_
[
cur_tree_id
]
&&
models_
.
size
()
<
static_cast
<
size_t
>
(
num_tree_per_iteration_
))
{
auto
output
=
class_default_output_
[
cur_tree_id
];
auto
output
=
class_default_output_
[
cur_tree_id
];
new_tree
->
Split
(
0
,
0
,
0
,
0
,
0
,
new_tree
->
AsConstantTree
(
output
);
output
,
output
,
0
,
0
,
-
1
,
MissingType
::
None
,
true
);
train_score_updater_
->
AddScore
(
output
,
cur_tree_id
);
train_score_updater_
->
AddScore
(
output
,
cur_tree_id
);
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
output
,
cur_tree_id
);
score_updater
->
AddScore
(
output
,
cur_tree_id
);
...
...
src/boosting/rf.hpp
View file @
cc83cd67
...
@@ -127,8 +127,7 @@ public:
...
@@ -127,8 +127,7 @@ public:
if
(
!
class_need_train_
[
cur_tree_id
]
&&
models_
.
size
()
<
static_cast
<
size_t
>
(
num_tree_per_iteration_
))
{
if
(
!
class_need_train_
[
cur_tree_id
]
&&
models_
.
size
()
<
static_cast
<
size_t
>
(
num_tree_per_iteration_
))
{
double
output
=
class_default_output_
[
cur_tree_id
];
double
output
=
class_default_output_
[
cur_tree_id
];
objective_function_
->
ConvertOutput
(
&
output
,
&
output
);
objective_function_
->
ConvertOutput
(
&
output
,
&
output
);
new_tree
->
Split
(
0
,
0
,
0
,
0
,
0
,
new_tree
->
AsConstantTree
(
output
);
output
,
output
,
0
,
0
,
-
1
,
MissingType
::
None
,
true
);
train_score_updater_
->
AddScore
(
output
,
cur_tree_id
);
train_score_updater_
->
AddScore
(
output
,
cur_tree_id
);
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
output
,
cur_tree_id
);
score_updater
->
AddScore
(
output
,
cur_tree_id
);
...
...
src/io/tree.cpp
View file @
cc83cd67
...
@@ -18,7 +18,6 @@ namespace LightGBM {
...
@@ -18,7 +18,6 @@ namespace LightGBM {
Tree
::
Tree
(
int
max_leaves
)
Tree
::
Tree
(
int
max_leaves
)
:
max_leaves_
(
max_leaves
)
{
:
max_leaves_
(
max_leaves
)
{
num_leaves_
=
0
;
left_child_
.
resize
(
max_leaves_
-
1
);
left_child_
.
resize
(
max_leaves_
-
1
);
right_child_
.
resize
(
max_leaves_
-
1
);
right_child_
.
resize
(
max_leaves_
-
1
);
split_feature_inner_
.
resize
(
max_leaves_
-
1
);
split_feature_inner_
.
resize
(
max_leaves_
-
1
);
...
@@ -36,6 +35,7 @@ Tree::Tree(int max_leaves)
...
@@ -36,6 +35,7 @@ Tree::Tree(int max_leaves)
// root is in the depth 0
// root is in the depth 0
leaf_depth_
[
0
]
=
0
;
leaf_depth_
[
0
]
=
0
;
num_leaves_
=
1
;
num_leaves_
=
1
;
leaf_value_
[
0
]
=
0.0
f
;
leaf_parent_
[
0
]
=
-
1
;
leaf_parent_
[
0
]
=
-
1
;
shrinkage_
=
1.0
f
;
shrinkage_
=
1.0
f
;
num_cat_
=
0
;
num_cat_
=
0
;
...
@@ -195,8 +195,6 @@ std::string Tree::ToString() {
...
@@ -195,8 +195,6 @@ std::string Tree::ToString() {
<<
Common
::
ArrayToString
<
int
>
(
left_child_
,
num_leaves_
-
1
,
' '
)
<<
std
::
endl
;
<<
Common
::
ArrayToString
<
int
>
(
left_child_
,
num_leaves_
-
1
,
' '
)
<<
std
::
endl
;
str_buf
<<
"right_child="
str_buf
<<
"right_child="
<<
Common
::
ArrayToString
<
int
>
(
right_child_
,
num_leaves_
-
1
,
' '
)
<<
std
::
endl
;
<<
Common
::
ArrayToString
<
int
>
(
right_child_
,
num_leaves_
-
1
,
' '
)
<<
std
::
endl
;
str_buf
<<
"leaf_parent="
<<
Common
::
ArrayToString
<
int
>
(
leaf_parent_
,
num_leaves_
,
' '
)
<<
std
::
endl
;
str_buf
<<
"leaf_value="
str_buf
<<
"leaf_value="
<<
Common
::
ArrayToString
<
double
>
(
leaf_value_
,
num_leaves_
,
' '
)
<<
std
::
endl
;
<<
Common
::
ArrayToString
<
double
>
(
leaf_value_
,
num_leaves_
,
' '
)
<<
std
::
endl
;
str_buf
<<
"leaf_count="
str_buf
<<
"leaf_count="
...
@@ -217,7 +215,7 @@ std::string Tree::ToJSON() {
...
@@ -217,7 +215,7 @@ std::string Tree::ToJSON() {
str_buf
<<
"
\"
num_cat
\"
:"
<<
num_cat_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
num_cat
\"
:"
<<
num_cat_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
shrinkage
\"
:"
<<
shrinkage_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
shrinkage
\"
:"
<<
shrinkage_
<<
","
<<
std
::
endl
;
if
(
num_leaves_
==
1
)
{
if
(
num_leaves_
==
1
)
{
str_buf
<<
"
\"
tree_structure
\"
:"
<<
NodeToJSON
(
-
1
)
<<
std
::
endl
;
str_buf
<<
"
\"
tree_structure
\"
:
{
"
<<
"
\"
leaf_value
\"
:"
<<
leaf_value_
[
0
]
<<
"}"
<<
std
::
endl
;
}
else
{
}
else
{
str_buf
<<
"
\"
tree_structure
\"
:"
<<
NodeToJSON
(
0
)
<<
std
::
endl
;
str_buf
<<
"
\"
tree_structure
\"
:"
<<
NodeToJSON
(
0
)
<<
std
::
endl
;
}
}
...
@@ -264,7 +262,6 @@ std::string Tree::NodeToJSON(int index) {
...
@@ -264,7 +262,6 @@ std::string Tree::NodeToJSON(int index) {
index
=
~
index
;
index
=
~
index
;
str_buf
<<
"{"
<<
std
::
endl
;
str_buf
<<
"{"
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_index
\"
:"
<<
index
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_index
\"
:"
<<
index
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_parent
\"
:"
<<
leaf_parent_
[
index
]
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_value
\"
:"
<<
leaf_value_
[
index
]
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_value
\"
:"
<<
leaf_value_
[
index
]
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_count
\"
:"
<<
leaf_count_
[
index
]
<<
std
::
endl
;
str_buf
<<
"
\"
leaf_count
\"
:"
<<
leaf_count_
[
index
]
<<
std
::
endl
;
str_buf
<<
"}"
;
str_buf
<<
"}"
;
...
@@ -280,8 +277,8 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
...
@@ -280,8 +277,8 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
str_buf
<<
"Leaf"
;
str_buf
<<
"Leaf"
;
}
}
str_buf
<<
"(const double* arr) { "
;
str_buf
<<
"(const double* arr) { "
;
if
(
num_leaves_
=
=
1
)
{
if
(
num_leaves_
<
=
1
)
{
str_buf
<<
"return
0
"
;
str_buf
<<
"return
"
<<
leaf_value_
[
0
]
<<
";
"
;
}
else
{
}
else
{
// use this for the missing value conversion
// use this for the missing value conversion
str_buf
<<
"double fval = 0.0f; "
;
str_buf
<<
"double fval = 0.0f; "
;
...
@@ -350,6 +347,12 @@ Tree::Tree(const std::string& str) {
...
@@ -350,6 +347,12 @@ Tree::Tree(const std::string& str) {
Common
::
Atoi
(
key_vals
[
"num_cat"
].
c_str
(),
&
num_cat_
);
Common
::
Atoi
(
key_vals
[
"num_cat"
].
c_str
(),
&
num_cat_
);
if
(
key_vals
.
count
(
"leaf_value"
))
{
leaf_value_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"leaf_value"
],
' '
,
num_leaves_
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain leaf_value field"
);
}
if
(
num_leaves_
<=
1
)
{
return
;
}
if
(
num_leaves_
<=
1
)
{
return
;
}
if
(
key_vals
.
count
(
"left_child"
))
{
if
(
key_vals
.
count
(
"left_child"
))
{
...
@@ -376,12 +379,6 @@ Tree::Tree(const std::string& str) {
...
@@ -376,12 +379,6 @@ Tree::Tree(const std::string& str) {
Log
::
Fatal
(
"Tree model string format error, should contain threshold field"
);
Log
::
Fatal
(
"Tree model string format error, should contain threshold field"
);
}
}
if
(
key_vals
.
count
(
"leaf_value"
))
{
leaf_value_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"leaf_value"
],
' '
,
num_leaves_
);
}
else
{
Log
::
Fatal
(
"Tree model string format error, should contain leaf_value field"
);
}
if
(
key_vals
.
count
(
"split_gain"
))
{
if
(
key_vals
.
count
(
"split_gain"
))
{
split_gain_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"split_gain"
],
' '
,
num_leaves_
-
1
);
split_gain_
=
Common
::
StringToArray
<
double
>
(
key_vals
[
"split_gain"
],
' '
,
num_leaves_
-
1
);
}
else
{
}
else
{
...
@@ -406,12 +403,6 @@ Tree::Tree(const std::string& str) {
...
@@ -406,12 +403,6 @@ Tree::Tree(const std::string& str) {
leaf_count_
.
resize
(
num_leaves_
);
leaf_count_
.
resize
(
num_leaves_
);
}
}
if
(
key_vals
.
count
(
"leaf_parent"
))
{
leaf_parent_
=
Common
::
StringToArray
<
int
>
(
key_vals
[
"leaf_parent"
],
' '
,
num_leaves_
);
}
else
{
leaf_parent_
.
resize
(
num_leaves_
);
}
if
(
key_vals
.
count
(
"decision_type"
))
{
if
(
key_vals
.
count
(
"decision_type"
))
{
decision_type_
=
Common
::
StringToArray
<
int8_t
>
(
key_vals
[
"decision_type"
],
' '
,
num_leaves_
-
1
);
decision_type_
=
Common
::
StringToArray
<
int8_t
>
(
key_vals
[
"decision_type"
],
' '
,
num_leaves_
-
1
);
}
else
{
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment