Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
f449a45b
Commit
f449a45b
authored
Dec 24, 2016
by
Guolin Ke
Browse files
reduce memory cost for multi classification
parent
381a945d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
24 deletions
+14
-24
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+12
-21
src/boosting/gbdt.h
src/boosting/gbdt.h
+2
-3
No files found.
src/boosting/gbdt.cpp
View file @
f449a45b
...
...
@@ -35,7 +35,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
random_
=
Random
(
config
->
bagging_seed
);
train_data_
=
nullptr
;
gbdt_config_
=
nullptr
;
tree_learner_
.
clear
()
;
tree_learner_
=
nullptr
;
ResetTrainingData
(
config
,
train_data
,
object_function
,
training_metrics
);
}
...
...
@@ -58,17 +58,11 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
if
(
train_data_
!=
train_data
&&
train_data
!=
nullptr
)
{
if
(
tree_learner_
.
empty
())
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
new_config
->
tree_learner_type
,
&
new_config
->
tree_config
));
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
if
(
tree_learner_
==
nullptr
)
{
tree_learner_
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
new_config
->
tree_learner_type
,
&
new_config
->
tree_config
));
}
// init tree learner
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
tree_learner_
[
i
]
->
Init
(
train_data
);
}
tree_learner_
->
Init
(
train_data
);
// push training metrics
training_metrics_
.
clear
();
...
...
@@ -114,9 +108,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
train_data_
=
train_data
;
if
(
train_data_
!=
nullptr
)
{
// reset config for tree learner
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
tree_learner_
[
i
]
->
ResetConfig
(
&
new_config
->
tree_config
);
}
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
}
gbdt_config_
.
reset
(
new_config
.
release
());
}
...
...
@@ -154,7 +146,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
}
void
GBDT
::
Bagging
(
int
iter
,
const
int
curr_class
)
{
void
GBDT
::
Bagging
(
int
iter
)
{
// if need bagging
if
(
!
out_of_bag_data_indices_
.
empty
()
&&
iter
%
gbdt_config_
->
bagging_freq
==
0
)
{
// if doesn't have query data
...
...
@@ -203,7 +195,7 @@ void GBDT::Bagging(int iter, const int curr_class) {
}
Log
::
Debug
(
"Re-bagging, using %d data to train"
,
bag_data_cnt_
);
// set bagging data to tree learner
tree_learner_
[
curr_class
]
->
SetBaggingData
(
bag_data_indices_
.
data
(),
bag_data_cnt_
);
tree_learner_
->
SetBaggingData
(
bag_data_indices_
.
data
(),
bag_data_cnt_
);
}
}
...
...
@@ -221,13 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
gradient
=
gradients_
.
data
();
hessian
=
hessians_
.
data
();
}
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
// bagging logic
Bagging
(
iter_
,
curr_class
);
Bagging
(
iter_
);
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
// train a new tree
std
::
unique_ptr
<
Tree
>
new_tree
(
tree_learner_
[
curr_class
]
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
std
::
unique_ptr
<
Tree
>
new_tree
(
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
// if cannot learn a new tree, then stop
if
(
new_tree
->
num_leaves
()
<=
1
)
{
Log
::
Info
(
"Stopped training because there are no more leafs that meet the split requirements."
);
...
...
@@ -290,7 +281,7 @@ bool GBDT::EvalAndCheckEarlyStopping() {
void
GBDT
::
UpdateScore
(
const
Tree
*
tree
,
const
int
curr_class
)
{
// update training score
train_score_updater_
->
AddScore
(
tree_learner_
[
curr_class
]
.
get
(),
curr_class
);
train_score_updater_
->
AddScore
(
tree_learner_
.
get
(),
curr_class
);
// update validation score
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
tree
,
curr_class
);
...
...
@@ -301,7 +292,7 @@ std::string GBDT::OutputMetric(int iter) {
bool
need_output
=
(
iter
%
gbdt_config_
->
output_freq
)
==
0
;
std
::
string
ret
=
""
;
std
::
stringstream
msg_buf
;
std
::
vector
<
std
::
pair
<
int
,
int
>>
meet_early_stopping_pairs
;
std
::
vector
<
std
::
pair
<
size_t
,
size_t
>>
meet_early_stopping_pairs
;
// print training metric
if
(
need_output
)
{
for
(
auto
&
sub_metric
:
training_metrics_
)
{
...
...
src/boosting/gbdt.h
View file @
f449a45b
...
...
@@ -214,9 +214,8 @@ protected:
/*!
* \brief Implement bagging logic
* \param iter Current iteration
* \param curr_class Current class for multiclass training
*/
void
Bagging
(
int
iter
,
const
int
curr_class
);
void
Bagging
(
int
iter
);
/*!
* \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training
...
...
@@ -252,7 +251,7 @@ protected:
/*! \brief Config of gbdt */
std
::
unique_ptr
<
BoostingConfig
>
gbdt_config_
;
/*! \brief Tree learner, will use this class to learn trees */
std
::
vector
<
std
::
unique_ptr
<
TreeLearner
>
>
tree_learner_
;
std
::
unique_ptr
<
TreeLearner
>
tree_learner_
;
/*! \brief Objective function */
const
ObjectiveFunction
*
object_function_
;
/*! \brief Store and update training data's score */
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment