tianlh / LightGBM-DCU
"include/LightGBM/vscode:/vscode.git/clone" did not exist on "1548b42bac5d5b7c295ba4d3132e8bda47e34fd1"
Commit c2e94f17, authored Dec 18, 2016 by Guolin Ke

    refine reset_parameters logic

Parent: 714c6732
Changes: 14 changed files with 226 additions and 127 deletions (+226 −127)
  include/LightGBM/boosting.h                         +0   −6
  include/LightGBM/tree_learner.h                     +9   −2
  src/boosting/dart.hpp                               +10  −5
  src/boosting/gbdt.cpp                               +41  −19
  src/boosting/gbdt.h                                 +1   −9
  src/c_api.cpp                                       +49  −39
  src/treelearner/data_parallel_tree_learner.cpp      +6   −3
  src/treelearner/feature_histogram.hpp               +26  −1
  src/treelearner/feature_parallel_tree_learner.cpp   +1   −1
  src/treelearner/parallel_tree_learner.h             +5   −4
  src/treelearner/serial_tree_learner.cpp             +51  −20
  src/treelearner/serial_tree_learner.h               +4   −2
  src/treelearner/tree_learner.cpp                    +1   −1
  src/treelearner/voting_parallel_tree_learner.cpp    +22  −15
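Taken together, the diff replaces one-shot construction (tree learners rebuilt whenever parameters change) with a create-once / reset-many pattern: learners now take a const TreeConfig* and expose ResetConfig, and GBDT owns a mutable copy of its BoostingConfig. A minimal, self-contained sketch of that pattern follows; TreeConfig and TreeLearner here are toy stand-ins, not the real LightGBM classes.

#include <memory>

struct TreeConfig { int num_leaves = 31; };  // toy stand-in

class TreeLearner {
 public:
  // The learner keeps a pointer to the config instead of a reference/copy...
  explicit TreeLearner(const TreeConfig* tree_config) : tree_config_(tree_config) {}
  // ...so a parameter change is a cheap re-point plus buffer resizing.
  void ResetConfig(const TreeConfig* tree_config) { tree_config_ = tree_config; }
 private:
  const TreeConfig* tree_config_;
};

int main() {
  auto config = std::unique_ptr<TreeConfig>(new TreeConfig());
  TreeLearner learner(config.get());        // created once
  auto new_config = std::unique_ptr<TreeConfig>(new TreeConfig(*config));
  new_config->num_leaves = 63;              // parameters changed mid-training
  learner.ResetConfig(new_config.get());    // learner reused, not rebuilt
  config = std::move(new_config);           // owner commits the new config
  return 0;
}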
include/LightGBM/boosting.h

@@ -51,12 +51,6 @@ public:
   */
   virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data,
     const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) = 0;
   /*!
   * \brief Reset shrinkage_rate data for current boosting
   * \param shrinkage_rate Configs for boosting
   */
   virtual void ResetShrinkageRate(double shrinkage_rate) = 0;
   /*!
   * \brief Add a validation data
   * \param valid_data Validation data
include/LightGBM/tree_learner.h

@@ -22,11 +22,17 @@ public:
   virtual ~TreeLearner() {}
   /*!
-  * \brief Initialize tree learner with training dataset and configs
+  * \brief Initialize tree learner with training dataset
   * \param train_data The used training data
   */
   virtual void Init(const Dataset* train_data) = 0;
+  /*!
+  * \brief Reset tree configs
+  * \param tree_config config of tree
+  */
+  virtual void ResetConfig(const TreeConfig* tree_config) = 0;
   /*!
   * \brief training tree model on dataset
   * \param gradients The first order gradients

@@ -58,9 +64,10 @@ public:
   /*!
   * \brief Create object of tree learner
   * \param type Type of tree learner
+  * \param tree_config config of tree
   */
   static TreeLearner* CreateTreeLearner(TreeLearnerType type,
-    const TreeConfig& tree_config);
+    const TreeConfig* tree_config);
 };
 }  // namespace LightGBM
src/boosting/dart.hpp

@@ -35,7 +35,6 @@ public:
   void Init(const BoostingConfig* config, const Dataset* train_data,
     const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) override {
     GBDT::Init(config, train_data, object_function, training_metrics);
-    drop_rate_ = gbdt_config_->drop_rate;
     shrinkage_rate_ = 1.0;
     random_for_drop_ = Random(gbdt_config_->drop_seed);
   }

@@ -53,6 +52,14 @@ public:
       return false;
     }
   }
+  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data,
+    const ObjectiveFunction* object_function,
+    const std::vector<const Metric*>& training_metrics) {
+    GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
+    shrinkage_rate_ = 1.0;
+    random_for_drop_ = Random(gbdt_config_->drop_seed);
+  }
   /*!
   * \brief Get current training score
   * \param out_len length of returned score

@@ -81,9 +88,9 @@ private:
     drop_index_.clear();
     // select dropping tree indexes based on drop_rate
     // if drop rate is too small, skip this step, drop one tree randomly
-    if (drop_rate_ > kEpsilon) {
+    if (gbdt_config_->drop_rate > kEpsilon) {
       for (int i = 0; i < iter_; ++i) {
-        if (random_for_drop_.NextDouble() < drop_rate_) {
+        if (random_for_drop_.NextDouble() < gbdt_config_->drop_rate) {
           drop_index_.push_back(i);
         }
       }

@@ -123,8 +130,6 @@ private:
   }
   /*! \brief The indexes of dropping trees */
   std::vector<int> drop_index_;
-  /*! \brief Dropping rate */
-  double drop_rate_;
   /*! \brief Random generator, used to select dropping trees */
   Random random_for_drop_;
   /*! \brief Flag that the score is update on current iter or not*/
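The DART change is a consequence of the same refactor: a drop rate cached in a member would go stale after a config reset, so the cached drop_rate_ is deleted and reads go through gbdt_config_ instead. A hedged toy illustration of the failure mode the removal avoids (not LightGBM code):

struct Config { double drop_rate; };

class CachedReader {
 public:
  explicit CachedReader(const Config* c) : config_(c), cached_rate_(c->drop_rate) {}
  double Stale() const { return cached_rate_; }        // frozen at construction
  double Fresh() const { return config_->drop_rate; }  // follows every reset
 private:
  const Config* config_;
  double cached_rate_;
};

int main() {
  Config cfg{0.1};
  CachedReader r(&cfg);
  cfg.drop_rate = 0.5;  // a ResetConfig-style update
  // r.Stale() is still 0.1; r.Fresh() is 0.5, matching the new code path.
  return (r.Fresh() == 0.5 && r.Stale() == 0.1) ? 0 : 1;
}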
src/boosting/gbdt.cpp

@@ -33,41 +33,57 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
   max_feature_idx_ = 0;
   num_class_ = config->num_class;
+  train_data_ = nullptr;
+  gbdt_config_ = nullptr;
+  ResetTrainingData(config, train_data, object_function, training_metrics);
+}
+
+void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data,
+  const ObjectiveFunction* object_function,
+  const std::vector<const Metric*>& training_metrics) {
+  if (train_data == nullptr) { return; }
+  auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
   if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
     Log::Fatal("cannot reset training data, since new training data has different bin mappers");
   }
-  gbdt_config_ = config;
-  early_stopping_round_ = gbdt_config_->early_stopping_round;
-  shrinkage_rate_ = gbdt_config_->learning_rate;
-  random_ = Random(gbdt_config_->bagging_seed);
-  // create tree learner
+  early_stopping_round_ = new_config->early_stopping_round;
+  shrinkage_rate_ = new_config->learning_rate;
+  random_ = Random(new_config->bagging_seed);
+  // create tree learner, only create once
+  if (gbdt_config_ == nullptr) {
     tree_learner_.clear();
     for (int i = 0; i < num_class_; ++i) {
+      // init tree learner
       auto new_tree_learner = std::unique_ptr<TreeLearner>(
-        TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
-      new_tree_learner->Init(train_data);
+        TreeLearner::CreateTreeLearner(new_config->tree_learner_type, &new_config->tree_config));
       tree_learner_.push_back(std::move(new_tree_learner));
     }
     tree_learner_.shrink_to_fit();
-  object_function_ = object_function;
-  // push training metrics
-  training_metrics_.clear();
-  for (const auto& metric : training_metrics) {
-    training_metrics_.push_back(metric);
-  }
-  training_metrics_.shrink_to_fit();
+  }
+  // init tree learner
+  if (train_data_ != train_data) {
+    for (int i = 0; i < num_class_; ++i) {
+      tree_learner_[i]->Init(train_data);
+    }
+  }
+  // reset config for tree learner
+  for (int i = 0; i < num_class_; ++i) {
+    tree_learner_[i]->ResetConfig(&new_config->tree_config);
+  }
+  object_function_ = object_function;
   sigmoid_ = -1.0f;
   if (object_function_ != nullptr
     && std::string(object_function_->GetName()) == std::string("binary")) {
     // only binary classification need sigmoid transform
-    sigmoid_ = gbdt_config_->sigmoid;
+    sigmoid_ = new_config->sigmoid;
   }
+  if (train_data_ != train_data) {
+    // push training metrics
+    training_metrics_.clear();
+    for (const auto& metric : training_metrics) {
+      training_metrics_.push_back(metric);
+    }
+    training_metrics_.shrink_to_fit();
+    // not same training data, need reset score and others
     // create score tracker
     train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));

@@ -88,8 +104,13 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
     max_feature_idx_ = train_data->num_total_features() - 1;
     // get label index
     label_idx_ = train_data->label_idx();
   }
+  if (train_data_ != train_data || gbdt_config_ == nullptr
+    || (gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
     // if need bagging, create buffer
-    if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
+    if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
       out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
       bag_data_indices_ = std::vector<data_size_t>(num_data_);
     } else {

@@ -100,6 +121,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
     }
   }
   train_data_ = train_data;
+  gbdt_config_.reset(new_config.release());
 }

 void GBDT::AddValidDataset(const Dataset* valid_data,
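GBDT::ResetTrainingData now works on a private copy of the incoming config and only commits it at the very end with gbdt_config_.reset(new_config.release()), so a failed reset (for example the CheckAlign fatal) never leaves a half-applied config behind. A minimal sketch of that copy-then-commit idiom, using toy stand-in types rather than the real GBDT:

#include <memory>

struct BoostingConfig { double learning_rate = 0.1; };  // toy stand-in

class Booster {  // hypothetical stand-in for GBDT
 public:
  void ResetTrainingData(const BoostingConfig* config) {
    // 1. Copy: never mutate the caller's object or the live config directly.
    auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
    // 2. Apply: derive all dependent state from the copy.
    shrinkage_rate_ = new_config->learning_rate;
    // 3. Commit: transfer ownership only after every step above succeeded.
    gbdt_config_.reset(new_config.release());
  }
 private:
  double shrinkage_rate_ = 0.0;
  std::unique_ptr<BoostingConfig> gbdt_config_;  // owning, as in the new gbdt.h
};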
src/boosting/gbdt.h

@@ -68,14 +68,6 @@ public:
   */
   void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data,
     const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) override;
   /*!
   * \brief Reset shrinkage_rate data for current boosting
   * \param shrinkage_rate Configs for boosting
   */
   void ResetShrinkageRate(double shrinkage_rate) override { shrinkage_rate_ = shrinkage_rate; }
   /*!
   * \brief Adding a validation dataset
   * \param valid_data Validation dataset

@@ -245,7 +237,7 @@ protected:
   /*! \brief Pointer to training data */
   const Dataset* train_data_;
   /*! \brief Config of gbdt */
-  const BoostingConfig* gbdt_config_;
+  std::unique_ptr<BoostingConfig> gbdt_config_;
   /*! \brief Tree learner, will use this class to learn trees */
   std::vector<std::unique_ptr<TreeLearner>> tree_learner_;
   /*! \brief Objective function */
src/c_api.cpp

@@ -40,12 +40,14 @@ public:
       Log::Warning("continued train from model is not support for c_api, \
        please use continued train with input score");
     }
     boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
-    train_data_ = train_data;
-    ConstructObjectAndTrainingMetrics(train_data);
     // initialize the boosting
-    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
+    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
       Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
+    ResetTrainingData(train_data);
   }

   void MergeFrom(const Booster* other) {

@@ -60,13 +62,34 @@ public:
   void ResetTrainingData(const Dataset* train_data) {
+    std::lock_guard<std::mutex> lock(mutex_);
     train_data_ = train_data;
-    ConstructObjectAndTrainingMetrics(train_data_);
-    // initialize the boosting
+    // create objective function
+    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
+      config_.objective_config));
+    if (objective_fun_ == nullptr) {
+      Log::Warning("Using self-defined objective function");
+    }
+    // initialize the objective function
+    if (objective_fun_ != nullptr) {
+      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
+    }
+    // create training metric
+    train_metric_.clear();
+    for (auto metric_type : config_.metric_types) {
+      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
+      if (metric == nullptr) { continue; }
+      metric->Init(train_data_->metadata(), train_data_->num_data());
+      train_metric_.push_back(std::move(metric));
+    }
+    train_metric_.shrink_to_fit();
+    // reset the boosting
     boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
       objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
   }

   void ResetConfig(const char* parameters) {
+    std::lock_guard<std::mutex> lock(mutex_);
     auto param = ConfigBase::Str2Map(parameters);
     if (param.count("num_class")) {
       Log::Fatal("cannot change num class during training");

@@ -77,21 +100,28 @@ public:
     if (param.count("metric")) {
       Log::Fatal("cannot change metric during training");
     }
-    {
-      std::lock_guard<std::mutex> lock(mutex_);
       config_.Set(param);
-    }
     if (config_.num_threads > 0) {
-      std::lock_guard<std::mutex> lock(mutex_);
       omp_set_num_threads(config_.num_threads);
     }
     if (param.size() == 1
       && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
       // only need to set learning rate
-      std::lock_guard<std::mutex> lock(mutex_);
       boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
     } else {
-      ResetTrainingData(train_data_);
+      if (param.count("objective")) {
+        // create objective function
+        objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
+          config_.objective_config));
+        if (objective_fun_ == nullptr) {
+          Log::Warning("Using self-defined objective function");
+        }
+        // initialize the objective function
+        if (objective_fun_ != nullptr) {
+          objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
+        }
+      }
+      boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
+        objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
     }
   }

   void AddValidData(const Dataset* valid_data) {

@@ -107,6 +137,7 @@ public:
     boosting_->AddValidDataset(valid_data,
       Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
   }

   bool TrainOneIter() {
+    std::lock_guard<std::mutex> lock(mutex_);
     return boosting_->TrainOneIter(nullptr, nullptr, false);

@@ -142,10 +173,12 @@ public:
   }

   std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
+    std::lock_guard<std::mutex> lock(mutex_);
     return predictor_->GetPredictFunction()(features);
   }

   void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
+    std::lock_guard<std::mutex> lock(mutex_);
     predictor_->Predict(data_filename, result_filename, data_has_header);
   }

@@ -180,29 +213,6 @@ public:
 private:
-  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
-    // create objective function
-    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
-      config_.objective_config));
-    if (objective_fun_ == nullptr) {
-      Log::Warning("Using self-defined objective functions");
-    }
-    // create training metric
-    train_metric_.clear();
-    for (auto metric_type : config_.metric_types) {
-      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
-      if (metric == nullptr) { continue; }
-      metric->Init(train_data->metadata(), train_data->num_data());
-      train_metric_.push_back(std::move(metric));
-    }
-    train_metric_.shrink_to_fit();
-    // initialize the objective function
-    if (objective_fun_ != nullptr) {
-      objective_fun_->Init(train_data->metadata(), train_data->num_data());
-    }
-  }
   const Dataset* train_data_;
   std::unique_ptr<Boosting> boosting_;
   /*! \brief All configs */
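The Booster changes also consolidate locking: instead of scattered scoped locks inside ResetConfig, each public entry point (ResetTrainingData, ResetConfig, TrainOneIter, Predict, PredictForFile) now takes mutex_ once on entry, and the code underneath calls boosting_->... directly rather than re-entering a locking wrapper, which would deadlock on a non-recursive std::mutex. A small sketch of the convention, with hypothetical class and method names:

#include <mutex>

class BoosterLike {  // hypothetical, for illustration only
 public:
  void ResetConfig() {
    std::lock_guard<std::mutex> lock(mutex_);  // one lock at the top
    ApplyLocked();                             // helper assumes lock is held
  }
  bool TrainOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);  // mirrors the lock added in c_api.cpp
    return ++iter_ > 0;
  }
 private:
  // Never locks: re-locking a std::mutex already held by this thread
  // is undefined behavior (typically a deadlock).
  void ApplyLocked() {}
  std::mutex mutex_;
  int iter_ = 0;
};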
src/treelearner/data_parallel_tree_learner.cpp

@@ -7,7 +7,7 @@
 namespace LightGBM {

-DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig& tree_config)
+DataParallelTreeLearner::DataParallelTreeLearner(const TreeConfig* tree_config)
   : SerialTreeLearner(tree_config) {
 }

@@ -37,10 +37,13 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {
   buffer_write_start_pos_.resize(num_features_);
   buffer_read_start_pos_.resize(num_features_);
-  global_data_count_in_leaf_.resize(tree_config_.num_leaves);
+  global_data_count_in_leaf_.resize(tree_config_->num_leaves);
 }

+void DataParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
+  SerialTreeLearner::ResetConfig(tree_config);
+  global_data_count_in_leaf_.resize(tree_config_->num_leaves);
+}

 void DataParallelTreeLearner::BeforeTrain() {
   SerialTreeLearner::BeforeTrain();
src/treelearner/feature_histogram.hpp

@@ -276,6 +276,10 @@ public:
   */
   void set_is_splittable(bool val) { is_splittable_ = val; }
+  void ResetConfig(const TreeConfig* tree_config) {
+    tree_config_ = tree_config;
+  }
 private:
   /*!
   * \brief Calculate the split gain based on regularized sum_gradients and sum_hessians

@@ -336,6 +340,8 @@ public:
   * \brief Constructor
   */
   HistogramPool() {
+    cache_size_ = 0;
+    total_size_ = 0;
   }
   /*!

@@ -348,7 +354,7 @@ public:
   * \param cache_size Max cache size
   * \param total_size Total size will be used
   */
-  void ResetSize(int cache_size, int total_size) {
+  void Reset(int cache_size, int total_size) {
     cache_size_ = cache_size;
     // at least need 2 bucket to store smaller leaf and larger leaf
     CHECK(cache_size_ >= 2);

@@ -382,6 +388,7 @@ public:
   * \param obj_create_fun that used to generate object
   */
   void Fill(std::function<FeatureHistogram*()> obj_create_fun) {
+    fill_func_ = obj_create_fun;
     pool_.clear();
     pool_.resize(cache_size_);
     for (int i = 0; i < cache_size_; ++i) {

@@ -389,6 +396,23 @@ public:
     }
   }
+  void DynamicChangeSize(int cache_size, int total_size) {
+    int old_cache_size = cache_size_;
+    Reset(cache_size, total_size);
+    pool_.resize(cache_size_);
+    for (int i = old_cache_size; i < cache_size_; ++i) {
+      pool_[i].reset(fill_func_());
+    }
+  }
+  void ResetConfig(const TreeConfig* tree_config, int array_size) {
+    for (int i = 0; i < cache_size_; ++i) {
+      auto data_ptr = pool_[i].get();
+      for (int j = 0; j < array_size; ++j) {
+        data_ptr[j].ResetConfig(tree_config);
+      }
+    }
+  }
   /*!
   * \brief Get data for the specific index
   * \param idx which index want to get

@@ -446,6 +470,7 @@ public:
 private:
   std::vector<std::unique_ptr<FeatureHistogram[]>> pool_;
+  std::function<FeatureHistogram*()> fill_func_;
   int cache_size_;
   int total_size_;
   bool is_enough_ = false;
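HistogramPool keeps the fill function in the new fill_func_ member precisely so DynamicChangeSize can construct only the slots added beyond the old cache_size_, leaving existing histograms (and their cached contents) untouched. A compact sketch of that grow-in-place idea with a toy element type, not the real FeatureHistogram:

#include <functional>
#include <memory>
#include <vector>

struct Entry { int id = 0; };  // toy stand-in for a FeatureHistogram array

class PoolSketch {
 public:
  void Fill(std::function<Entry*()> fn, int cache_size) {
    fill_func_ = fn;  // remembered so later growth can reuse it
    cache_size_ = cache_size;
    pool_.resize(cache_size_);
    for (int i = 0; i < cache_size_; ++i) pool_[i].reset(fill_func_());
  }
  void DynamicChangeSize(int cache_size) {
    int old_cache_size = cache_size_;
    cache_size_ = cache_size;
    pool_.resize(cache_size_);
    // Only the new tail is constructed; slots [0, old_cache_size) survive as-is.
    for (int i = old_cache_size; i < cache_size_; ++i) pool_[i].reset(fill_func_());
  }
 private:
  std::vector<std::unique_ptr<Entry>> pool_;
  std::function<Entry*()> fill_func_;
  int cache_size_ = 0;
};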
src/treelearner/feature_parallel_tree_learner.cpp

@@ -6,7 +6,7 @@
 namespace LightGBM {

-FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig& tree_config)
+FeatureParallelTreeLearner::FeatureParallelTreeLearner(const TreeConfig* tree_config)
   : SerialTreeLearner(tree_config) {
 }
src/treelearner/parallel_tree_learner.h

@@ -20,7 +20,7 @@ namespace LightGBM {
 */
 class FeatureParallelTreeLearner : public SerialTreeLearner {
 public:
-  explicit FeatureParallelTreeLearner(const TreeConfig& tree_config);
+  explicit FeatureParallelTreeLearner(const TreeConfig* tree_config);
   ~FeatureParallelTreeLearner();
   virtual void Init(const Dataset* train_data);

@@ -45,9 +45,10 @@ private:
 */
 class DataParallelTreeLearner : public SerialTreeLearner {
 public:
-  explicit DataParallelTreeLearner(const TreeConfig& tree_config);
+  explicit DataParallelTreeLearner(const TreeConfig* tree_config);
   ~DataParallelTreeLearner();
   void Init(const Dataset* train_data) override;
+  void ResetConfig(const TreeConfig* tree_config) override;
 protected:
   void BeforeTrain() override;
   void FindBestThresholds() override;

@@ -96,10 +97,10 @@ private:
 */
 class VotingParallelTreeLearner : public SerialTreeLearner {
 public:
-  explicit VotingParallelTreeLearner(const TreeConfig& tree_config);
+  explicit VotingParallelTreeLearner(const TreeConfig* tree_config);
   ~VotingParallelTreeLearner() { }
   void Init(const Dataset* train_data) override;
+  void ResetConfig(const TreeConfig* tree_config) override;
 protected:
   void BeforeTrain() override;
   bool BeforeFindBestSplit(int left_leaf, int right_leaf) override;
src/treelearner/serial_tree_learner.cpp

@@ -7,9 +7,9 @@
 namespace LightGBM {

-SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
+SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
   : tree_config_(tree_config) {
-  random_ = Random(tree_config.feature_fraction_seed);
+  random_ = Random(tree_config_->feature_fraction_seed);
 }

 SerialTreeLearner::~SerialTreeLearner() {

@@ -22,32 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
   num_features_ = train_data_->num_features();
   int max_cache_size = 0;
   // Get the max size of pool
-  if (tree_config_.histogram_pool_size < 0) {
-    max_cache_size = tree_config_.num_leaves;
+  if (tree_config_->histogram_pool_size <= 0) {
+    max_cache_size = tree_config_->num_leaves;
   } else {
     size_t total_histogram_size = 0;
     for (int i = 0; i < train_data_->num_features(); ++i) {
       total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
     }
-    max_cache_size = static_cast<int>(tree_config_.histogram_pool_size * 1024 * 1024 / total_histogram_size);
+    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
   }
   // at least need 2 leaves
   max_cache_size = std::max(2, max_cache_size);
-  max_cache_size = std::min(max_cache_size, tree_config_.num_leaves);
-  histogram_pool_.ResetSize(max_cache_size, tree_config_.num_leaves);
+  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
+  histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);
   auto histogram_create_function = [this]() {
     auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
     for (int j = 0; j < train_data_->num_features(); ++j) {
       tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
-        j, &tree_config_);
+        j, tree_config_);
     }
     return tmp_histogram_array.release();
   };
   histogram_pool_.Fill(histogram_create_function);
   // push split information for all leaves
-  best_split_per_leaf_.resize(tree_config_.num_leaves);
+  best_split_per_leaf_.resize(tree_config_->num_leaves);
   // initialize ordered_bins_ with nullptr
   ordered_bins_.resize(num_features_);

@@ -69,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
   larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
   // initialize data partition
-  data_partition_.reset(new DataPartition(num_data_, tree_config_.num_leaves));
+  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
   is_feature_used_.resize(num_features_);

@@ -84,19 +84,49 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
 }

+void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
+  if (tree_config_->num_leaves != tree_config->num_leaves) {
+    tree_config_ = tree_config;
+    int max_cache_size = 0;
+    // Get the max size of pool
+    if (tree_config->histogram_pool_size <= 0) {
+      max_cache_size = tree_config_->num_leaves;
+    } else {
+      size_t total_histogram_size = 0;
+      for (int i = 0; i < train_data_->num_features(); ++i) {
+        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
+      }
+      max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
+    }
+    // at least need 2 leaves
+    max_cache_size = std::max(2, max_cache_size);
+    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
+    histogram_pool_.DynamicChangeSize(max_cache_size, tree_config_->num_leaves);
+    // push split information for all leaves
+    best_split_per_leaf_.resize(tree_config_->num_leaves);
+    data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
+  } else {
+    tree_config_ = tree_config;
+  }
+  histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
+  random_ = Random(tree_config_->feature_fraction_seed);
+}

 Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
   gradients_ = gradients;
   hessians_ = hessians;
   // some initial works before training
   BeforeTrain();
-  auto tree = std::unique_ptr<Tree>(new Tree(tree_config_.num_leaves));
+  auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
   // save pointer to last trained tree
   last_trained_tree_ = tree.get();
   // root leaf
   int left_leaf = 0;
   // only root leaf can be splitted on first time
   int right_leaf = -1;
-  for (int split = 0; split < tree_config_.num_leaves - 1; split++) {
+  for (int split = 0; split < tree_config_->num_leaves - 1; split++) {
     // some initial works before finding best split
     if (BeforeFindBestSplit(left_leaf, right_leaf)) {
       // find best threshold for every feature

@@ -121,6 +151,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
 }

 void SerialTreeLearner::BeforeTrain() {
+  // reset histogram pool
   histogram_pool_.ResetMap();
   // initialize used features

@@ -128,7 +159,7 @@ void SerialTreeLearner::BeforeTrain() {
     is_feature_used_[i] = false;
   }
   // Get used feature at current tree
-  int used_feature_cnt = static_cast<int>(num_features_ * tree_config_.feature_fraction);
+  int used_feature_cnt = static_cast<int>(num_features_ * tree_config_->feature_fraction);
   auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
   for (auto idx : used_feature_indices) {
     is_feature_used_[idx] = true;

@@ -138,7 +169,7 @@ void SerialTreeLearner::BeforeTrain() {
   data_partition_->Init();
   // reset the splits for leaves
-  for (int i = 0; i < tree_config_.num_leaves; ++i) {
+  for (int i = 0; i < tree_config_->num_leaves; ++i) {
     best_split_per_leaf_[i].Reset();
   }

@@ -177,7 +208,7 @@ void SerialTreeLearner::BeforeTrain() {
 #pragma omp parallel for schedule(guided)
     for (int i = 0; i < num_features_; ++i) {
       if (ordered_bins_[i] != nullptr) {
-        ordered_bins_[i]->Init(nullptr, tree_config_.num_leaves);
+        ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
       }
     }
   } else {

@@ -196,7 +227,7 @@ void SerialTreeLearner::BeforeTrain() {
 #pragma omp parallel for schedule(guided)
     for (int i = 0; i < num_features_; ++i) {
       if (ordered_bins_[i] != nullptr) {
-        ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_.num_leaves);
+        ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
       }
     }
   }

@@ -205,9 +236,9 @@ void SerialTreeLearner::BeforeTrain() {
 bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
   // check depth of current leaf
-  if (tree_config_.max_depth > 0) {
+  if (tree_config_->max_depth > 0) {
     // only need to check left leaf, since right leaf is in same level of left leaf
-    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_.max_depth) {
+    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
       best_split_per_leaf_[left_leaf].gain = kMinScore;
       if (right_leaf >= 0) {
         best_split_per_leaf_[right_leaf].gain = kMinScore;

@@ -218,8 +249,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
   data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
   data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
   // no enough data to continue
-  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)
-    && num_data_in_left_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)) {
+  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
+    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
     best_split_per_leaf_[left_leaf].gain = kMinScore;
     if (right_leaf >= 0) {
       best_split_per_leaf_[right_leaf].gain = kMinScore;
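The pool-sizing arithmetic that Init and the new ResetConfig share converts a histogram_pool_size budget in MB into a leaf count, clamped to the range [2, num_leaves]; note the comparison also changed from < 0 to <= 0, so a zero budget now means "no limit" as well. A worked example with made-up numbers, not values from the source:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative values only.
  double histogram_pool_size = 16.0;          // budget in MB; <= 0 means unlimited
  int num_leaves = 255;
  std::size_t total_histogram_size = 512 * 1024;  // bytes per cached histogram array

  int max_cache_size = 0;
  if (histogram_pool_size <= 0) {
    max_cache_size = num_leaves;
  } else {
    max_cache_size = static_cast<int>(histogram_pool_size * 1024 * 1024 / total_histogram_size);
  }
  max_cache_size = std::max(2, max_cache_size);           // at least two leaves cached
  max_cache_size = std::min(max_cache_size, num_leaves);  // never more than num_leaves
  std::printf("cache %d of %d leaves\n", max_cache_size, num_leaves);  // cache 32 of 255
  return 0;
}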
src/treelearner/serial_tree_learner.h

@@ -26,12 +26,14 @@ namespace LightGBM {
 */
 class SerialTreeLearner : public TreeLearner {
 public:
-  explicit SerialTreeLearner(const TreeConfig& tree_config);
+  explicit SerialTreeLearner(const TreeConfig* tree_config);
   ~SerialTreeLearner();
   void Init(const Dataset* train_data) override;
+  void ResetConfig(const TreeConfig* tree_config) override;
   Tree* Train(const score_t* gradients, const score_t* hessians) override;
   void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {

@@ -153,7 +155,7 @@ protected:
   /*! \brief used to cache historical histogram to speed up*/
   HistogramPool histogram_pool_;
   /*! \brief config of tree learner*/
-  const TreeConfig& tree_config_;
+  const TreeConfig* tree_config_;
 };
src/treelearner/tree_learner.cpp

@@ -5,7 +5,7 @@
 namespace LightGBM {

-TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig& tree_config) {
+TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig* tree_config) {
   if (type == TreeLearnerType::kSerialTreeLearner) {
     return new SerialTreeLearner(tree_config);
   } else if (type == TreeLearnerType::kFeatureParallelTreelearner) {
src/treelearner/voting_parallel_tree_learner.cpp

@@ -9,9 +9,9 @@
 namespace LightGBM {

-VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig& tree_config)
+VotingParallelTreeLearner::VotingParallelTreeLearner(const TreeConfig* tree_config)
   : SerialTreeLearner(tree_config) {
-  top_k_ = tree_config.top_k;
+  top_k_ = tree_config_->top_k;
 }

 void VotingParallelTreeLearner::Init(const Dataset* train_data) {

@@ -44,34 +44,41 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data) {
   smaller_buffer_read_start_pos_.resize(num_features_);
   larger_buffer_read_start_pos_.resize(num_features_);
-  global_data_count_in_leaf_.resize(tree_config_.num_leaves);
+  global_data_count_in_leaf_.resize(tree_config_->num_leaves);
   smaller_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
   larger_leaf_splits_global_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
-  local_tree_config_ = tree_config_;
+  local_tree_config_ = *tree_config_;
   local_tree_config_.min_data_in_leaf /= num_machines_;
   local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
-  auto histogram_create_function = [this]() {
-    auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
-    for (int j = 0; j < train_data_->num_features(); ++j) {
-      tmp_histogram_array[j].Init(train_data_->FeatureAt(j), j, &local_tree_config_);
-    }
-    return tmp_histogram_array.release();
-  };
-  histogram_pool_.Fill(histogram_create_function);
+  histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());
   // initialize histograms for global
   smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
   larger_leaf_histogram_array_global_.reset(new FeatureHistogram[num_features_]);
   for (int j = 0; j < num_features_; ++j) {
-    smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
-    larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, &tree_config_);
+    smaller_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
+    larger_leaf_histogram_array_global_[j].Init(train_data_->FeatureAt(j), j, tree_config_);
   }
 }

+void VotingParallelTreeLearner::ResetConfig(const TreeConfig* tree_config) {
+  SerialTreeLearner::ResetConfig(tree_config);
+  local_tree_config_ = *tree_config_;
+  local_tree_config_.min_data_in_leaf /= num_machines_;
+  local_tree_config_.min_sum_hessian_in_leaf /= num_machines_;
+  histogram_pool_.ResetConfig(&local_tree_config_, train_data_->num_features());
+  global_data_count_in_leaf_.resize(tree_config_->num_leaves);
+  for (int j = 0; j < num_features_; ++j) {
+    smaller_leaf_histogram_array_global_[j].ResetConfig(tree_config_);
+    larger_leaf_histogram_array_global_[j].ResetConfig(tree_config_);
+  }
+}

 void VotingParallelTreeLearner::BeforeTrain() {
   SerialTreeLearner::BeforeTrain();
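VotingParallelTreeLearner::ResetConfig re-derives its per-machine local_tree_config_ by copying the global config and dividing the data-dependent thresholds by num_machines_, so each worker enforces its proportional share of the global constraints. A tiny worked sketch with a toy struct and illustrative numbers:

struct TreeConfigSketch {  // toy stand-in for TreeConfig
  int min_data_in_leaf;
  double min_sum_hessian_in_leaf;
};

int main() {
  TreeConfigSketch global{100, 10.0};
  int num_machines = 4;
  TreeConfigSketch local = global;                // copy the (possibly reset) global config
  local.min_data_in_leaf /= num_machines;         // 100 / 4 = 25 per machine
  local.min_sum_hessian_in_leaf /= num_machines;  // 10.0 / 4 = 2.5 per machine
  return (local.min_data_in_leaf == 25) ? 0 : 1;
}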