Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
9db054cf
Commit
9db054cf
authored
Dec 20, 2016
by
Guolin Ke
Browse files
fix bug in ResetTrainingData
parent
99b483dd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
30 deletions
+24
-30
include/LightGBM/utils/log.h
include/LightGBM/utils/log.h
+1
-1
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+23
-29
No files found.
include/LightGBM/utils/log.h
View file @
9db054cf
...
@@ -19,7 +19,7 @@ namespace LightGBM {
...
@@ -19,7 +19,7 @@ namespace LightGBM {
#ifndef CHECK_NOTNULL
// Fatal-if-null guard for raw pointers.
// Uses the preprocessor stringizing operator (#pointer) to embed the checked
// expression in the message, and __FILE__/__LINE__ so the failure report points
// at the call site rather than just naming the pointer (this is the fix this
// commit introduces over the old message, which carried no location info).
// NOTE(review): deliberately a bare `if` with no else/braces, matching the
// upstream macro; callers must not follow it with a dangling `else`.
#define CHECK_NOTNULL(pointer) \
  if ((pointer) == nullptr) LightGBM::Log::Fatal(#pointer " Can't be NULL at %s, line %d .\n", __FILE__, __LINE__);
#endif  // CHECK_NOTNULL
...
...
src/boosting/gbdt.cpp
View file @
9db054cf
...
@@ -32,44 +32,21 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
...
@@ -32,44 +32,21 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
num_iteration_for_pred_
=
0
;
num_iteration_for_pred_
=
0
;
max_feature_idx_
=
0
;
max_feature_idx_
=
0
;
num_class_
=
config
->
num_class
;
num_class_
=
config
->
num_class
;
random_
=
Random
(
config
->
bagging_seed
);
train_data_
=
nullptr
;
train_data_
=
nullptr
;
gbdt_config_
=
nullptr
;
gbdt_config_
=
nullptr
;
tree_learner_
.
clear
();
ResetTrainingData
(
config
,
train_data
,
object_function
,
training_metrics
);
ResetTrainingData
(
config
,
train_data
,
object_function
,
training_metrics
);
}
}
void
GBDT
::
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
void
GBDT
::
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
if
(
train_data
==
nullptr
)
{
return
;
}
auto
new_config
=
std
::
unique_ptr
<
BoostingConfig
>
(
new
BoostingConfig
(
*
config
));
auto
new_config
=
std
::
unique_ptr
<
BoostingConfig
>
(
new
BoostingConfig
(
*
config
));
if
(
train_data_
!=
nullptr
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
if
(
train_data_
!=
nullptr
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
}
}
early_stopping_round_
=
new_config
->
early_stopping_round
;
early_stopping_round_
=
new_config
->
early_stopping_round
;
shrinkage_rate_
=
new_config
->
learning_rate
;
shrinkage_rate_
=
new_config
->
learning_rate
;
// cannot reset seed, only create one time
if
(
gbdt_config_
==
nullptr
)
{
random_
=
Random
(
new_config
->
bagging_seed
);
}
// create tree learner, only create once
if
(
gbdt_config_
==
nullptr
)
{
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
new_config
->
tree_learner_type
,
&
new_config
->
tree_config
));
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
}
// init tree learner
if
(
train_data_
!=
train_data
)
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
tree_learner_
[
i
]
->
Init
(
train_data
);
}
}
// reset config for tree learner
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
tree_learner_
[
i
]
->
ResetConfig
(
&
new_config
->
tree_config
);
}
object_function_
=
object_function
;
object_function_
=
object_function
;
...
@@ -80,7 +57,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -80,7 +57,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
sigmoid_
=
new_config
->
sigmoid
;
sigmoid_
=
new_config
->
sigmoid
;
}
}
if
(
train_data_
!=
train_data
)
{
if
(
train_data_
!=
train_data
&&
train_data
!=
nullptr
)
{
if
(
tree_learner_
.
empty
())
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
new_config
->
tree_learner_type
,
&
new_config
->
tree_config
));
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
}
// init tree learner
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
tree_learner_
[
i
]
->
Init
(
train_data
);
}
// push training metrics
// push training metrics
training_metrics_
.
clear
();
training_metrics_
.
clear
();
for
(
const
auto
&
metric
:
training_metrics
)
{
for
(
const
auto
&
metric
:
training_metrics
)
{
...
@@ -109,9 +98,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -109,9 +98,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
label_idx_
=
train_data
->
label_idx
();
label_idx_
=
train_data
->
label_idx
();
}
}
if
(
train_data_
!=
train_data
if
((
train_data_
!=
train_data
&&
train_data
!=
nullptr
)
||
gbdt_config_
==
nullptr
||
(
gbdt_config_
!=
nullptr
&&
gbdt_config_
->
bagging_fraction
!=
new_config
->
bagging_fraction
))
{
||
(
gbdt_config_
->
bagging_fraction
!=
new_config
->
bagging_fraction
))
{
// if need bagging, create buffer
// if need bagging, create buffer
if
(
new_config
->
bagging_fraction
<
1.0
&&
new_config
->
bagging_freq
>
0
)
{
if
(
new_config
->
bagging_fraction
<
1.0
&&
new_config
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
.
resize
(
num_data_
);
out_of_bag_data_indices_
.
resize
(
num_data_
);
...
@@ -124,6 +112,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -124,6 +112,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
}
}
}
train_data_
=
train_data
;
train_data_
=
train_data
;
if
(
train_data_
!=
nullptr
)
{
// reset config for tree learner
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
tree_learner_
[
i
]
->
ResetConfig
(
&
new_config
->
tree_config
);
}
}
gbdt_config_
.
reset
(
new_config
.
release
());
gbdt_config_
.
reset
(
new_config
.
release
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment