Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
13329682
"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b752170ba950890dd7e211fdecedef4d8deffb8b"
Commit
13329682
authored
Nov 23, 2016
by
Guolin Ke
Browse files
support rollback iteration and reset config during training.
parent
422c0ef7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
131 additions
and
0 deletions
+131
-0
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+19
-0
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+21
-0
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+46
-0
src/boosting/gbdt.h
src/boosting/gbdt.h
+14
-0
src/c_api.cpp
src/c_api.cpp
+31
-0
No files found.
include/LightGBM/boosting.h
View file @
13329682
...
@@ -35,6 +35,12 @@ public:
...
@@ -35,6 +35,12 @@ public:
const
ObjectiveFunction
*
object_function
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
virtual
void
ResetConfig
(
const
BoostingConfig
*
config
)
=
0
;
/*!
/*!
* \brief Add a validation data
* \brief Add a validation data
* \param valid_data Validation data
* \param valid_data Validation data
...
@@ -52,6 +58,19 @@ public:
...
@@ -52,6 +58,19 @@ public:
*/
*/
virtual
bool
TrainOneIter
(
const
score_t
*
gradient
,
const
score_t
*
hessian
,
bool
is_eval
)
=
0
;
virtual
bool
TrainOneIter
(
const
score_t
*
gradient
,
const
score_t
*
hessian
,
bool
is_eval
)
=
0
;
/*!
* \brief Rollback one iteration
*/
virtual
void
RollbackOneIter
()
=
0
;
/*!
* \brief return current iteration
*/
virtual
int
GetCurrentIteration
()
const
=
0
;
/*!
* \brief Eval metrics and check is met early stopping or not
*/
virtual
bool
EvalAndCheckEarlyStopping
()
=
0
;
virtual
bool
EvalAndCheckEarlyStopping
()
=
0
;
/*!
/*!
* \brief Get evaluation result at data_idx data
* \brief Get evaluation result at data_idx data
...
...
include/LightGBM/c_api.h
View file @
13329682
...
@@ -239,6 +239,7 @@ DllExport int LGBM_BoosterCreateFromModelfile(
...
@@ -239,6 +239,7 @@ DllExport int LGBM_BoosterCreateFromModelfile(
int64_t
*
out_num_total_model
,
int64_t
*
out_num_total_model
,
BoosterHandle
*
out
);
BoosterHandle
*
out
);
/*!
/*!
* \brief free obj in handle
* \brief free obj in handle
* \param handle handle to be freed
* \param handle handle to be freed
...
@@ -246,6 +247,13 @@ DllExport int LGBM_BoosterCreateFromModelfile(
...
@@ -246,6 +247,13 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/
*/
DllExport
int
LGBM_BoosterFree
(
BoosterHandle
handle
);
DllExport
int
LGBM_BoosterFree
(
BoosterHandle
handle
);
/*!
* \brief Reset config for current booster
* \param parameters format: 'key1=value1 key2=value2'
* \return 0 when success, -1 when failure happens
*/
DllExport
int
LGBM_BoosterResetParameter
(
BoosterHandle
handle
,
const
char
*
parameters
);
/*!
/*!
* \brief Get number of class
* \brief Get number of class
* \return number of class
* \return number of class
...
@@ -274,6 +282,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
...
@@ -274,6 +282,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const
float
*
hess
,
const
float
*
hess
,
int
*
is_finished
);
int
*
is_finished
);
/*!
* \brief Rollback one iteration
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
DllExport
int
LGBM_BoosterRollbackOneIter
(
BoosterHandle
handle
);
/*!
* \brief Get iteration of current boosting rounds
* \return iteration of boosting rounds
*/
DllExport
int
LGBM_BoosterGetCurrentIteration
(
BoosterHandle
handle
,
int64_t
*
out_iteration
);
/*!
/*!
* \brief Get number of eval
* \brief Get number of eval
* \return total number of eval result
* \return total number of eval result
...
...
src/boosting/gbdt.cpp
View file @
13329682
...
@@ -36,6 +36,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
...
@@ -36,6 +36,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
train_data_
=
train_data
;
train_data_
=
train_data
;
num_class_
=
config
->
num_class
;
num_class_
=
config
->
num_class
;
// create tree learner
// create tree learner
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
new_tree_learner
->
Init
(
train_data_
);
new_tree_learner
->
Init
(
train_data_
);
...
@@ -82,6 +83,32 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
...
@@ -82,6 +83,32 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
}
}
void
GBDT
::
ResetConfig
(
const
BoostingConfig
*
config
)
{
gbdt_config_
=
config
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
// create tree learner
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
new_tree_learner
->
Init
(
train_data_
);
// init tree learner
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
// if need bagging, create buffer
if
(
gbdt_config_
->
bagging_fraction
<
1.0
&&
gbdt_config_
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
}
else
{
out_of_bag_data_cnt_
=
0
;
out_of_bag_data_indices_
.
clear
();
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
}
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
void
GBDT
::
AddDataset
(
const
Dataset
*
valid_data
,
void
GBDT
::
AddDataset
(
const
Dataset
*
valid_data
,
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
if
(
iter_
>
0
)
{
if
(
iter_
>
0
)
{
...
@@ -204,6 +231,25 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -204,6 +231,25 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
}
}
void
GBDT
::
RollbackOneIter
()
{
if
(
iter_
==
0
)
{
return
;
}
int
cur_iter
=
iter_
-
1
;
// reset score
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
cur_iter
*
num_class_
+
curr_class
;
models_
[
curr_tree
]
->
Shrinkage
(
-
1.0
);
train_score_updater_
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
}
}
// remove model
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
models_
.
pop_back
();
}
--
iter_
;
}
bool
GBDT
::
EvalAndCheckEarlyStopping
()
{
bool
GBDT
::
EvalAndCheckEarlyStopping
()
{
bool
is_met_early_stopping
=
false
;
bool
is_met_early_stopping
=
false
;
// print message for metric
// print message for metric
...
...
src/boosting/gbdt.h
View file @
13329682
...
@@ -35,6 +35,13 @@ public:
...
@@ -35,6 +35,13 @@ public:
void
Init
(
const
BoostingConfig
*
gbdt_config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
void
Init
(
const
BoostingConfig
*
gbdt_config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
override
;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
void
ResetConfig
(
const
BoostingConfig
*
config
)
override
;
/*!
/*!
* \brief Adding a validation dataset
* \brief Adding a validation dataset
* \param valid_data Validation dataset
* \param valid_data Validation dataset
...
@@ -51,6 +58,13 @@ public:
...
@@ -51,6 +58,13 @@ public:
*/
*/
virtual
bool
TrainOneIter
(
const
score_t
*
gradient
,
const
score_t
*
hessian
,
bool
is_eval
)
override
;
virtual
bool
TrainOneIter
(
const
score_t
*
gradient
,
const
score_t
*
hessian
,
bool
is_eval
)
override
;
/*!
* \brief Rollback one iteration
*/
void
RollbackOneIter
()
override
;
int
GetCurrentIteration
()
const
override
{
return
iter_
;
}
bool
EvalAndCheckEarlyStopping
()
override
;
bool
EvalAndCheckEarlyStopping
()
override
;
/*!
/*!
...
...
src/c_api.cpp
View file @
13329682
...
@@ -150,6 +150,17 @@ public:
...
@@ -150,6 +150,17 @@ public:
return
idx
;
return
idx
;
}
}
void
ResetBoostingConfig
(
const
char
*
parameters
)
{
OverallConfig
new_config
;
new_config
.
LoadFromString
(
parameters
);
config_
.
boosting_config
=
new_config
.
boosting_config
;
boosting_
->
ResetConfig
(
&
config_
.
boosting_config
);
}
void
RollbackOneIter
()
{
boosting_
->
RollbackOneIter
();
}
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
private:
private:
...
@@ -471,6 +482,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
...
@@ -471,6 +482,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END
();
API_END
();
}
}
DllExport
int
LGBM_BoosterResetParameter
(
BoosterHandle
handle
,
const
char
*
parameters
)
{
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
ResetBoostingConfig
(
parameters
);
API_END
();
}
DllExport
int
LGBM_BoosterGetNumClasses
(
BoosterHandle
handle
,
int64_t
*
out_len
)
{
DllExport
int
LGBM_BoosterGetNumClasses
(
BoosterHandle
handle
,
int64_t
*
out_len
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
...
@@ -503,6 +521,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
...
@@ -503,6 +521,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
API_END
();
API_END
();
}
}
DllExport
int
LGBM_BoosterRollbackOneIter
(
BoosterHandle
handle
)
{
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
RollbackOneIter
();
API_END
();
}
DllExport
int
LGBM_BoosterGetCurrentIteration
(
BoosterHandle
handle
,
int64_t
*
out_iteration
)
{
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
*
out_iteration
=
ref_booster
->
GetBoosting
()
->
GetCurrentIteration
();
API_END
();
}
/*!
/*!
* \brief Get number of eval
* \brief Get number of eval
* \return total number of eval result
* \return total number of eval result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment