Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
4aa63045
Commit
4aa63045
authored
Oct 24, 2016
by
Guolin Ke
Browse files
support max_depth option.
parent
9895116d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
37 additions
and
3 deletions
+37
-3
include/LightGBM/config.h
include/LightGBM/config.h
+6
-0
include/LightGBM/tree.h
include/LightGBM/tree.h
+5
-0
src/io/config.cpp
src/io/config.cpp
+2
-0
src/io/tree.cpp
src/io/tree.cpp
+8
-1
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+14
-2
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+2
-0
No files found.
include/LightGBM/config.h
View file @
4aa63045
...
@@ -137,11 +137,17 @@ struct TreeConfig: public ConfigBase {
...
@@ -137,11 +137,17 @@ struct TreeConfig: public ConfigBase {
public:
public:
int
min_data_in_leaf
=
100
;
int
min_data_in_leaf
=
100
;
double
min_sum_hessian_in_leaf
=
10.0
f
;
double
min_sum_hessian_in_leaf
=
10.0
f
;
// should > 1, only one leaf means not need to learning
int
num_leaves
=
127
;
int
num_leaves
=
127
;
int
feature_fraction_seed
=
2
;
int
feature_fraction_seed
=
2
;
double
feature_fraction
=
1.0
;
double
feature_fraction
=
1.0
;
// max cache size(unit:MB) for historical histogram. < 0 means not limit
// max cache size(unit:MB) for historical histogram. < 0 means not limit
double
histogram_pool_size
=
-
1
;
double
histogram_pool_size
=
-
1
;
// max depth of tree model.
// Still grow tree by leaf-wise, but limit the max depth to avoid over-fitting
// And the max leaves will be min(num_leaves, pow(2, max_depth - 1))
// max_depth < 0 means not limit
int
max_depth
=
-
1
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
};
};
...
...
include/LightGBM/tree.h
View file @
4aa63045
...
@@ -80,6 +80,9 @@ public:
...
@@ -80,6 +80,9 @@ public:
/*! \brief Get Number of leaves*/
/*! \brief Get Number of leaves*/
inline
int
num_leaves
()
const
{
return
num_leaves_
;
}
inline
int
num_leaves
()
const
{
return
num_leaves_
;
}
/*! \brief Get depth of specific leaf*/
inline
int
leaf_depth
(
int
leaf_idx
)
const
{
return
leaf_depth_
[
leaf_idx
];
}
/*!
/*!
* \brief Shrinkage for the tree's output
* \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a learning rate) is used to tune the traning process
* shrinkage rate (a.k.a learning rate) is used to tune the traning process
...
@@ -139,6 +142,8 @@ private:
...
@@ -139,6 +142,8 @@ private:
int
*
leaf_parent_
;
int
*
leaf_parent_
;
/*! \brief Output of leaves */
/*! \brief Output of leaves */
score_t
*
leaf_value_
;
score_t
*
leaf_value_
;
/*! \brief Depth for leaves */
int
*
leaf_depth_
;
};
};
...
...
src/io/config.cpp
View file @
4aa63045
...
@@ -228,6 +228,8 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
...
@@ -228,6 +228,8 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble
(
params
,
"feature_fraction"
,
&
feature_fraction
);
GetDouble
(
params
,
"feature_fraction"
,
&
feature_fraction
);
CHECK
(
feature_fraction
>
0.0
&&
feature_fraction
<=
1.0
);
CHECK
(
feature_fraction
>
0.0
&&
feature_fraction
<=
1.0
);
GetDouble
(
params
,
"histogram_pool_size"
,
&
histogram_pool_size
);
GetDouble
(
params
,
"histogram_pool_size"
,
&
histogram_pool_size
);
GetInt
(
params
,
"max_depth"
,
&
max_depth
);
CHECK
(
max_depth
>
1
||
max_depth
<
0
);
}
}
...
...
src/io/tree.cpp
View file @
4aa63045
...
@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves)
...
@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves)
leaf_parent_
=
new
int
[
max_leaves_
];
leaf_parent_
=
new
int
[
max_leaves_
];
leaf_value_
=
new
score_t
[
max_leaves_
];
leaf_value_
=
new
score_t
[
max_leaves_
];
leaf_depth_
=
new
int
[
max_leaves_
];
// root is in the depth 1
leaf_depth_
[
0
]
=
1
;
num_leaves_
=
1
;
num_leaves_
=
1
;
leaf_parent_
[
0
]
=
-
1
;
leaf_parent_
[
0
]
=
-
1
;
}
}
...
@@ -41,6 +44,7 @@ Tree::~Tree() {
...
@@ -41,6 +44,7 @@ Tree::~Tree() {
if
(
threshold_
!=
nullptr
)
{
delete
[]
threshold_
;
}
if
(
threshold_
!=
nullptr
)
{
delete
[]
threshold_
;
}
if
(
split_gain_
!=
nullptr
)
{
delete
[]
split_gain_
;
}
if
(
split_gain_
!=
nullptr
)
{
delete
[]
split_gain_
;
}
if
(
leaf_value_
!=
nullptr
)
{
delete
[]
leaf_value_
;
}
if
(
leaf_value_
!=
nullptr
)
{
delete
[]
leaf_value_
;
}
if
(
leaf_depth_
!=
nullptr
)
{
delete
[]
leaf_depth_
;
}
}
}
int
Tree
::
Split
(
int
leaf
,
int
feature
,
unsigned
int
threshold_bin
,
int
real_feature
,
int
Tree
::
Split
(
int
leaf
,
int
feature
,
unsigned
int
threshold_bin
,
int
real_feature
,
...
@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
...
@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
leaf_parent_
[
num_leaves_
]
=
new_node_idx
;
leaf_parent_
[
num_leaves_
]
=
new_node_idx
;
leaf_value_
[
leaf
]
=
left_value
;
leaf_value_
[
leaf
]
=
left_value
;
leaf_value_
[
num_leaves_
]
=
right_value
;
leaf_value_
[
num_leaves_
]
=
right_value
;
// update leaf depth
leaf_depth_
[
num_leaves_
]
=
leaf_depth_
[
leaf
]
+
1
;
leaf_depth_
[
leaf
]
++
;
++
num_leaves_
;
++
num_leaves_
;
return
num_leaves_
-
1
;
return
num_leaves_
-
1
;
}
}
...
@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) {
...
@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) {
split_feature_
=
nullptr
;
split_feature_
=
nullptr
;
threshold_in_bin_
=
nullptr
;
threshold_in_bin_
=
nullptr
;
leaf_depth_
=
nullptr
;
Common
::
StringToIntArray
(
key_vals
[
"split_feature"
],
' '
,
Common
::
StringToIntArray
(
key_vals
[
"split_feature"
],
' '
,
num_leaves_
-
1
,
split_feature_real_
);
num_leaves_
-
1
,
split_feature_real_
);
...
...
src/treelearner/serial_tree_learner.cpp
View file @
4aa63045
...
@@ -19,6 +19,7 @@ SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
...
@@ -19,6 +19,7 @@ SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
feature_fraction_
=
tree_config
.
feature_fraction
;
feature_fraction_
=
tree_config
.
feature_fraction
;
random_
=
Random
(
tree_config
.
feature_fraction_seed
);
random_
=
Random
(
tree_config
.
feature_fraction_seed
);
histogram_pool_size_
=
tree_config
.
histogram_pool_size
;
histogram_pool_size_
=
tree_config
.
histogram_pool_size
;
max_depth_
=
tree_config
.
max_depth
;
}
}
SerialTreeLearner
::~
SerialTreeLearner
()
{
SerialTreeLearner
::~
SerialTreeLearner
()
{
...
@@ -120,6 +121,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
...
@@ -120,6 +121,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training
// some initial works before training
BeforeTrain
();
BeforeTrain
();
Tree
*
tree
=
new
Tree
(
num_leaves_
);
Tree
*
tree
=
new
Tree
(
num_leaves_
);
// save pointer to last trained tree
last_trained_tree_
=
tree
;
// root leaf
// root leaf
int
left_leaf
=
0
;
int
left_leaf
=
0
;
// only root leaf can be splitted on first time
// only root leaf can be splitted on first time
...
@@ -145,8 +148,6 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
...
@@ -145,8 +148,6 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// split tree with best leaf
// split tree with best leaf
Split
(
tree
,
best_leaf
,
&
left_leaf
,
&
right_leaf
);
Split
(
tree
,
best_leaf
,
&
left_leaf
,
&
right_leaf
);
}
}
// save pointer to last trained tree
last_trained_tree_
=
tree
;
return
tree
;
return
tree
;
}
}
...
@@ -234,6 +235,17 @@ void SerialTreeLearner::BeforeTrain() {
...
@@ -234,6 +235,17 @@ void SerialTreeLearner::BeforeTrain() {
}
}
bool
SerialTreeLearner
::
BeforeFindBestSplit
(
int
left_leaf
,
int
right_leaf
)
{
bool
SerialTreeLearner
::
BeforeFindBestSplit
(
int
left_leaf
,
int
right_leaf
)
{
// check depth of current leaf
if
(
max_depth_
>
0
)
{
// only need to check left leaf, since right leaf is in same level of left leaf
if
(
last_trained_tree_
->
leaf_depth
(
left_leaf
)
>=
max_depth_
)
{
best_split_per_leaf_
[
left_leaf
].
gain
=
kMinScore
;
if
(
right_leaf
>=
0
)
{
best_split_per_leaf_
[
right_leaf
].
gain
=
kMinScore
;
}
return
false
;
}
}
data_size_t
num_data_in_left_child
=
GetGlobalDataCountInLeaf
(
left_leaf
);
data_size_t
num_data_in_left_child
=
GetGlobalDataCountInLeaf
(
left_leaf
);
data_size_t
num_data_in_right_child
=
GetGlobalDataCountInLeaf
(
right_leaf
);
data_size_t
num_data_in_right_child
=
GetGlobalDataCountInLeaf
(
right_leaf
);
// no enough data to continue
// no enough data to continue
...
...
src/treelearner/serial_tree_learner.h
View file @
4aa63045
...
@@ -163,6 +163,8 @@ protected:
...
@@ -163,6 +163,8 @@ protected:
double
histogram_pool_size_
;
double
histogram_pool_size_
;
/*! \brief used to cache historical histogram to speed up*/
/*! \brief used to cache historical histogram to speed up*/
LRUPool
<
FeatureHistogram
*>
histogram_pool_
;
LRUPool
<
FeatureHistogram
*>
histogram_pool_
;
/*! \brief max depth of tree model */
int
max_depth_
;
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment