tianlh / LightGBM-DCU / Commits

Commit 4948dcc5
authored Jan 23, 2017 by Guolin Ke

    improve bagging speed

parent c8fbd42b

Showing 15 changed files with 190 additions and 229 deletions (+190 / -229)
include/LightGBM/bin.h                               +2    -0
include/LightGBM/dataset.h                           +3    -1
include/LightGBM/feature.h                           +3    -0
include/LightGBM/utils/random.h                      +8    -16
src/boosting/dart.hpp                                +3    -3
src/boosting/gbdt.cpp                                +17   -15
src/boosting/gbdt.h                                  +1    -3
src/c_api.cpp                                        +3    -6
src/io/dataset.cpp                                   +22   -16
src/io/dense_bin.hpp                                 +7    -0
src/io/sparse_bin.hpp                                +5    -0
src/treelearner/data_parallel_tree_learner.cpp       +15   -20
src/treelearner/feature_histogram.hpp                +59   -106
src/treelearner/serial_tree_learner.cpp              +22   -24
src/treelearner/voting_parallel_tree_learner.cpp     +20   -19
include/LightGBM/bin.h

@@ -289,6 +289,8 @@ public:
   /*! \brief Number of all data */
   virtual data_size_t num_data() const = 0;
+
+  virtual void ReSize(data_size_t num_data) = 0;
   /*!
   * \brief Construct histogram of this feature,
   * Note: We use ordered_gradients and ordered_hessians to improve cache hit chance
include/LightGBM/dataset.h

@@ -328,7 +328,9 @@ public:
     return used_feature_map_[col_idx];
   }
-  Dataset* Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse, bool need_meta_data) const;
+  void ReSize(data_size_t num_data);
+
+  void CopySubset(const Dataset* fullset, const data_size_t* used_indices, data_size_t num_used_indices, bool need_meta_data);
   LIGHTGBM_EXPORT void FinishLoad();
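The ReSize/CopySubset pair replaces the removed Subset factory: instead of allocating and loading a fresh Dataset for every bagging round, a caller keeps one preallocated container and refills it. A minimal sketch of that reuse pattern, assuming the hypothetical wrapper below (the real call sites are in gbdt.cpp and c_api.cpp further down):

    // Sketch only: the names follow the diff, but RefillBag itself is mine.
    #include <memory>

    void RefillBag(const Dataset* train_data, std::unique_ptr<Dataset>* tmp_subset,
                   const data_size_t* bag_indices, data_size_t bag_cnt) {
      if (*tmp_subset == nullptr) {
        // one-time setup: allocate and copy the bin mappers, but no row data
        tmp_subset->reset(new Dataset(bag_cnt));
        (*tmp_subset)->CopyFeatureMapperFrom(train_data, /*is_enable_sparse=*/false);
      }
      (*tmp_subset)->ReSize(bag_cnt);          // no-op when the size is unchanged
      (*tmp_subset)->CopySubset(train_data, bag_indices, bag_cnt,
                                /*need_meta_data=*/false);  // gather rows in place
    }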
include/LightGBM/feature.h

@@ -83,6 +83,9 @@ public:
   inline void PushBin(int tid, data_size_t line_idx, unsigned int bin) {
     bin_data_->Push(tid, line_idx, bin);
   }
+
+  void ReSize(data_size_t num_data) { bin_data_->ReSize(num_data); }
+
   inline void FinishLoad() { bin_data_->FinishLoad(); }
   /*! \brief Index of this feature */
   inline int feature_index() const { return feature_index_; }
include/LightGBM/utils/random.h

@@ -35,15 +35,15 @@ public:
   * \return The random integer between [lower_bound, upper_bound)
   */
   inline int NextInt(int lower_bound, int upper_bound) {
-    return (next()) % (upper_bound - lower_bound) + lower_bound;
+    return (fastrand()) % (upper_bound - lower_bound) + lower_bound;
   }
   /*!
   * \brief Generate random float data
   * \return The random float between [0.0, 1.0)
   */
-  inline double NextDouble() {
+  inline float NextFloat() {
     // get random float in [0,1)
-    return static_cast<double>(next() % 2047) / 2047.0f;
+    return static_cast<float>(fastrand()) / (32768.0f);
   }
   /*!
   * \brief Sample K data from {0,1,...,N-1}

@@ -58,26 +58,18 @@ public:
     }
     for (int i = 0; i < N; ++i) {
       double prob = (K - ret.size()) / static_cast<double>(N - i);
-      if (NextDouble() < prob) {
+      if (NextFloat() < prob) {
         ret.push_back(i);
       }
     }
     return ret;
   }
 private:
-  unsigned next() {
-    x ^= x << 16;
-    x ^= x >> 5;
-    x ^= x << 1;
-    auto t = x;
-    x = y;
-    y = z;
-    z = t ^ x ^ y;
-    return z;
-  }
-  unsigned int x = 123456789;
-  unsigned int y = 362436069;
-  unsigned int z = 521288629;
+  inline int fastrand() {
+    x = (214013 * x + 2531011);
+    return (x >> 16) & 0x7FFF;
+  }
+  int x = 123456789;
 };
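The generator swap is the heart of the speedup in this header: the three-word xorshift next() becomes the classic "fastrand" linear congruential generator, one multiply-add per draw, returning 15-bit values, which is why NextFloat() divides by 32768.0f to stay in [0.0, 1.0). A standalone sketch of the new generator (class name and seeding constructor are illustrative; the member functions mirror the diff):

    class FastRandomSketch {
     public:
      explicit FastRandomSketch(int seed) : x(seed) {}
      // one LCG step; the high 15 bits of the state are the output
      inline int fastrand() {
        x = 214013 * x + 2531011;
        return (x >> 16) & 0x7FFF;                             // in [0, 32768)
      }
      inline int NextInt(int lower_bound, int upper_bound) {   // [lower, upper)
        return fastrand() % (upper_bound - lower_bound) + lower_bound;
      }
      inline float NextFloat() {                               // [0.0f, 1.0f), step 1/32768
        return static_cast<float>(fastrand()) / 32768.0f;
      }
     private:
      int x = 123456789;  // the diff keeps x as a plain int; an unsigned state
                          // would sidestep signed-overflow pedantry
    };

Besides being cheaper per draw, the state shrinks from three 32-bit words to one, which matters because bagging constructs many short-lived generators (see gbdt.cpp below).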
src/boosting/dart.hpp

@@ -78,7 +78,7 @@ private:
   */
   void DroppingTrees() {
     drop_index_.clear();
-    bool is_skip = random_for_drop_.NextDouble() < gbdt_config_->skip_drop;
+    bool is_skip = random_for_drop_.NextFloat() < gbdt_config_->skip_drop;
     // select dropping tree indexes based on drop_rate and tree weights
     if (!is_skip) {
       double drop_rate = gbdt_config_->drop_rate;

@@ -88,7 +88,7 @@ private:
         drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
       }
       for (int i = 0; i < iter_; ++i) {
-        if (random_for_drop_.NextDouble() < drop_rate * tree_weight_[i] * inv_average_weight) {
+        if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
           drop_index_.push_back(i);
         }
       }

@@ -97,7 +97,7 @@ private:
         drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
       }
       for (int i = 0; i < iter_; ++i) {
-        if (random_for_drop_.NextDouble() < drop_rate) {
+        if (random_for_drop_.NextFloat() < drop_rate) {
           drop_index_.push_back(i);
         }
       }
src/boosting/gbdt.cpp

@@ -46,9 +46,6 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
   num_iteration_for_pred_ = 0;
   max_feature_idx_ = 0;
   num_class_ = config->num_class;
-  for (int i = 0; i < num_threads_; ++i) {
-    random_.emplace_back(config->bagging_seed + i);
-  }
   train_data_ = nullptr;
   gbdt_config_ = nullptr;
   tree_learner_ = nullptr;

@@ -136,7 +133,9 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
   right_write_pos_buf_.resize(num_threads_);
   double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
   is_use_subset_ = false;
-  if (average_bag_rate < 0.3) {
+  if (average_bag_rate < 0.5) {
+    tmp_subset_.reset(new Dataset(bag_data_cnt_));
+    tmp_subset_->CopyFeatureMapperFrom(train_data, false);
     is_use_subset_ = true;
     Log::Debug("use subset for bagging");
   }

@@ -187,20 +186,20 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
   valid_metrics_.back().shrink_to_fit();
 }

-data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer){
-  const int tid = omp_get_thread_num();
+data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer){
   data_size_t bag_data_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
   data_size_t cur_left_cnt = 0;
   data_size_t cur_right_cnt = 0;
+  auto right_buffer = buffer + bag_data_cnt;
   // random bagging, minimal unit is one record
   for (data_size_t i = 0; i < cnt; ++i) {
-    double prob = (bag_data_cnt - cur_left_cnt) / static_cast<double>(cnt - i);
-    if (random_[tid].NextDouble() < prob) {
+    float prob = (bag_data_cnt - cur_left_cnt) / static_cast<float>(cnt - i);
+    if (cur_rand.NextFloat() < prob) {
       buffer[cur_left_cnt++] = start + i;
     } else {
-      buffer[bag_data_cnt + cur_right_cnt++] = start + i;
+      right_buffer[cur_right_cnt++] = start + i;
     }
   }
   CHECK(buffer[bag_data_cnt - 1] > buffer[bag_data_cnt]);

@@ -208,14 +207,16 @@ data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t*
   return cur_left_cnt;
 }

 void GBDT::Bagging(int iter) {
   // if need bagging
   if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
-    const data_size_t min_inner_size = 10000;
+    const data_size_t min_inner_size = 1000;
     data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
     if (inner_size < min_inner_size) { inner_size = min_inner_size; }
-    #pragma omp parallel for schedule(static, 1)
+    #pragma omp parallel for schedule(static,1)
     for (int i = 0; i < num_threads_; ++i) {
       left_cnts_buf_[i] = 0;
       right_cnts_buf_[i] = 0;

@@ -223,7 +224,8 @@ void GBDT::Bagging(int iter) {
       if (cur_start > num_data_) { continue; }
       data_size_t cur_cnt = inner_size;
       if (cur_start + cur_cnt > num_data_) { cur_cnt = num_data_ - cur_start; }
-      data_size_t cur_left_count = BaggingHelper(cur_start, cur_cnt, tmp_indices_.data() + cur_start);
+      Random cur_rand(gbdt_config_->bagging_seed + iter * num_threads_ + i);
+      data_size_t cur_left_count = BaggingHelper(cur_rand, cur_start, cur_cnt, tmp_indices_.data() + cur_start);
       offsets_buf_[i] = cur_start;
       left_cnts_buf_[i] = cur_left_count;
       right_cnts_buf_[i] = cur_cnt - cur_left_count;

@@ -256,8 +258,8 @@ void GBDT::Bagging(int iter) {
       tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
     } else {
       // get subset
-      tmp_subset_.reset(train_data_->Subset(bag_data_indices_.data(), bag_data_cnt_, false, false));
-      tmp_subset_->FinishLoad();
+      tmp_subset_->ReSize(bag_data_cnt_);
+      tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
       tree_learner_->ResetTrainingData(tmp_subset_.get());
     }
   }
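Two things change in the bagging path. First, the shared std::vector<Random> indexed by omp_get_thread_num() is gone; each chunk builds its own generator from bagging_seed + iter * num_threads_ + i, so the sample depends only on (iteration, chunk), never on OS thread scheduling. Second, BaggingHelper is a single-pass selection sampler: keeping row i with probability (bag_data_cnt - cur_left_cnt) / (cnt - i) yields exactly bag_data_cnt in-bag rows (Knuth's Algorithm S). A self-contained demonstration of that exact-count property, using the same LCG step as the new generator (the demo scaffolding is mine):

    #include <cstdio>
    #include <vector>

    int main() {
      const int n = 100000, k = 30000;            // rows, in-bag target
      unsigned int x = 42;                        // LCG state, as in random.h
      auto next_float = [&]() {
        x = 214013 * x + 2531011;
        return ((x >> 16) & 0x7FFF) / 32768.0f;   // [0, 1)
      };
      std::vector<int> in_bag;
      for (int i = 0; i < n; ++i) {
        // probability of keeping row i, given how many are still needed:
        // hits exactly 0 once k rows are chosen, exactly 1 when every
        // remaining row must be taken, so the count always lands on k
        float prob = (k - in_bag.size()) / static_cast<float>(n - i);
        if (next_float() < prob) in_bag.push_back(i);
      }
      std::printf("in-bag rows: %zu\n", in_bag.size());  // prints exactly k
      return 0;
    }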
src/boosting/gbdt.h

@@ -235,7 +235,7 @@ protected:
   * \param buffer output buffer
   * \return count of left size
   */
-  virtual data_size_t BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer);
+  data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer);
   /*!
   * \brief updating score for out-of-bag data.
   * Data should be update since we may re-bagging data on training

@@ -308,8 +308,6 @@ protected:
   data_size_t num_data_;
   /*! \brief Number of classes */
   int num_class_;
-  /*! \brief Random generator, used for bagging */
-  std::vector<Random> random_;
   /*!
   * \brief Sigmoid parameter, used for prediction.
   * if > 0 means output score will transform by sigmoid function
src/c_api.cpp

@@ -485,12 +485,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
   IOConfig io_config;
   io_config.Set(param);
   auto full_dataset = reinterpret_cast<const Dataset*>(handle);
-  auto ret = std::unique_ptr<Dataset>(full_dataset->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse, true));
-  ret->FinishLoad();
+  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
+  ret->CopyFeatureMapperFrom(full_dataset, io_config.is_enable_sparse);
+  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
   *out = ret.release();
   API_END();
 }
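The C API endpoint now builds the subset with the same CopyFeatureMapperFrom + CopySubset pair instead of the deleted Subset factory, so external callers get the faster path too. A hedged sketch of how a client might exercise it; the prototype is reconstructed from the parameter names visible in this hunk, not copied from c_api.h, so treat it as illustrative:

    // Hypothetical caller; prototype inferred from the hunk's parameters.
    #include <cstdint>

    typedef void* DatasetHandle;
    extern "C" int LGBM_DatasetGetSubset(DatasetHandle handle,
                                         const int32_t* used_row_indices,
                                         int32_t num_used_row_indices,
                                         const char* parameters,
                                         DatasetHandle* out);

    int MakeSlice(DatasetHandle full, DatasetHandle* slice) {
      // ascending indices, matching the forward sweep CopySubset performs
      const int32_t rows[] = {0, 2, 5, 7};
      return LGBM_DatasetGetSubset(full, rows, 4, "", slice);
    }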
src/io/dataset.cpp

@@ -40,38 +40,44 @@ void Dataset::FinishLoad() {

 void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_sparse) {
   features_.clear();
+  num_features_ = dataset->num_features_;
   // copy feature bin mapper data
-  for (const auto& feature : dataset->features_) {
-    features_.emplace_back(std::unique_ptr<Feature>(new Feature(feature->feature_index(), new BinMapper(*feature->bin_mapper()), num_data_, is_enable_sparse)));
+  for (int i = 0; i < num_features_; ++i){
+    features_.emplace_back(new Feature(dataset->features_[i]->feature_index(), new BinMapper(*(dataset->features_[i]->bin_mapper())), num_data_, is_enable_sparse));
   }
   features_.shrink_to_fit();
   used_feature_map_ = dataset->used_feature_map_;
-  num_features_ = static_cast<int>(features_.size());
   num_total_features_ = dataset->num_total_features_;
   feature_names_ = dataset->feature_names_;
   label_idx_ = dataset->label_idx_;
 }

-Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse, bool need_meta_data) const {
-  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_indices));
-  ret->CopyFeatureMapperFrom(this, is_enable_sparse);
+void Dataset::ReSize(data_size_t num_data) {
+  if (num_data_ != num_data) {
+    num_data_ = num_data;
+    #pragma omp parallel for schedule(guided)
+    for (int fidx = 0; fidx < num_features_; ++fidx) {
+      features_[fidx]->ReSize(num_data_);
+    }
+  }
+}
+
+void Dataset::CopySubset(const Dataset* fullset, const data_size_t* used_indices, data_size_t num_used_indices, bool need_meta_data) {
+  CHECK(num_used_indices == num_data_);
   #pragma omp parallel for schedule(guided)
   for (int fidx = 0; fidx < num_features_; ++fidx) {
-    auto iterator = features_[fidx]->bin_data()->GetIterator(0);
+    auto iterator = fullset->features_[fidx]->bin_data()->GetIterator(used_indices[0]);
     for (data_size_t i = 0; i < num_used_indices; ++i) {
-      ret->features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i]));
+      features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i]));
     }
   }
   if (need_meta_data) {
-    ret->metadata_.Init(metadata_, used_indices, num_used_indices);
+    metadata_.Init(metadata_, used_indices, num_used_indices);
   }
-  return ret.release();
+  FinishLoad();
 }

 bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) {
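CopySubset is the workhorse: for each feature, in parallel across features, it walks a bin iterator over the full dataset and pushes the sampled rows into this container. CHECK(num_used_indices == num_data_) insists the caller ReSize'd first, and seeding the iterator at used_indices[0] assumes the indices sweep forward, which bagging's chunk-ordered output provides. Stripped of the Bin machinery, the per-feature access pattern is just a gather; a sketch over a plain column (the function name and array framing are mine):

    // Illustrative restatement of one feature's copy loop in CopySubset.
    template <typename T>
    void GatherColumn(const T* full_column, const data_size_t* used_indices,
                      data_size_t num_used, T* subset_column) {
      // forward sweep: used_indices ascends, so reads move monotonically
      // through the column; that is what makes the iterator version cheap
      for (data_size_t i = 0; i < num_used; ++i) {
        subset_column[i] = full_column[used_indices[i]];
      }
    }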
src/io/dense_bin.hpp

@@ -33,6 +33,13 @@ public:
     data_[idx] = static_cast<VAL_T>(value);
   }
+
+  void ReSize(data_size_t num_data) override {
+    if (num_data_ != num_data) {
+      num_data_ = num_data;
+      data_.resize(num_data_);
+    }
+  }
   inline uint32_t Get(data_size_t idx) const {
     return static_cast<uint32_t>(data_[idx]);
   }
src/io/sparse_bin.hpp

@@ -67,6 +67,11 @@ public:
   }

   ~SparseBin() {
   }
+
+  void ReSize(data_size_t num_data) override {
+    num_data_ = num_data;
+  }
+
   void Push(int tid, data_size_t idx, uint32_t value) override {
src/treelearner/data_parallel_tree_learner.cpp

@@ -131,22 +131,19 @@ void DataParallelTreeLearner::FindBestThresholds() {
     if ((!is_feature_used_.empty() && is_feature_used_[feature_index] == false)) continue;
     // construct histograms for smaller leaf
     if (ordered_bins_[feature_index] == nullptr) {
-      smaller_leaf_histogram_array_[feature_index].Construct(train_data_->FeatureAt(feature_index)->bin_data(),
+      // if not use ordered bin
+      train_data_->FeatureAt(feature_index)->bin_data()->ConstructHistogram(
         smaller_leaf_splits_->data_indices(),
         smaller_leaf_splits_->num_data_in_leaf(),
-        smaller_leaf_splits_->sum_gradients(),
-        smaller_leaf_splits_->sum_hessians(),
         ptr_to_ordered_gradients_smaller_leaf_,
-        ptr_to_ordered_hessians_smaller_leaf_);
+        ptr_to_ordered_hessians_smaller_leaf_,
+        smaller_leaf_histogram_array_[feature_index].GetData());
     } else {
-      smaller_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
-        smaller_leaf_splits_->LeafIndex(),
-        smaller_leaf_splits_->num_data_in_leaf(),
-        smaller_leaf_splits_->sum_gradients(),
-        smaller_leaf_splits_->sum_hessians(),
+      // used ordered bin
+      ordered_bins_[feature_index]->ConstructHistogram(smaller_leaf_splits_->LeafIndex(),
         gradients_,
-        hessians_);
+        hessians_,
+        smaller_leaf_histogram_array_[feature_index].GetData());
     }
     // copy to buffer
     std::memcpy(input_buffer_.data() + buffer_write_start_pos_[feature_index],

@@ -160,11 +157,6 @@ void DataParallelTreeLearner::FindBestThresholds() {
   #pragma omp parallel for schedule(guided)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
     if (!is_feature_aggregated_[feature_index]) continue;
-    // copy global sumup info
-    smaller_leaf_histogram_array_[feature_index].SetSumup(
-      GetGlobalDataCountInLeaf(smaller_leaf_splits_->LeafIndex()),
-      smaller_leaf_splits_->sum_gradients(),
-      smaller_leaf_splits_->sum_hessians());
     // restore global histograms from buffer
     smaller_leaf_histogram_array_[feature_index].FromMemory(

@@ -172,6 +164,9 @@ void DataParallelTreeLearner::FindBestThresholds() {
     // find best threshold for smaller child
     smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
+      smaller_leaf_splits_->sum_gradients(),
+      smaller_leaf_splits_->sum_hessians(),
+      smaller_leaf_splits_->num_data_in_leaf(),
       &smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);
     // only root leaf

@@ -180,12 +175,12 @@ void DataParallelTreeLearner::FindBestThresholds() {
     // construct histgroms for large leaf, we init larger leaf as the parent, so we can just subtract the smaller leaf's histograms
     larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
-    // set sumup info for histogram
-    larger_leaf_histogram_array_[feature_index].SetSumup(
-      GetGlobalDataCountInLeaf(larger_leaf_splits_->LeafIndex()),
-      larger_leaf_splits_->sum_gradients(),
-      larger_leaf_splits_->sum_hessians());
     // find best threshold for larger child
     larger_leaf_histogram_array_[feature_index].FindBestThreshold(
+      larger_leaf_splits_->sum_gradients(),
+      larger_leaf_splits_->sum_hessians(),
+      larger_leaf_splits_->num_data_in_leaf(),
       &larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
   }
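The data-parallel learner inherits both changes: histograms are built straight into GetData() buffers, and the per-machine pieces are merged with Network::ReduceScatter using HistogramBinEntry::SumReducer before the (now argument-passing) FindBestThreshold runs. The reducer's job, restated over typed entries rather than the raw byte buffers the network layer actually passes, is element-wise accumulation (sketch, under that assumption):

    // Element-wise merge of two histogram blocks; the field names are the
    // ones Subtract() uses in feature_histogram.hpp.
    void SumHistogramBlock(const HistogramBinEntry* src, HistogramBinEntry* dst,
                           int num_bin) {
      for (int i = 0; i < num_bin; ++i) {
        dst[i].sum_gradients += src[i].sum_gradients;
        dst[i].sum_hessians  += src[i].sum_hessians;
        dst[i].cnt           += src[i].cnt;
      }
    }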
src/treelearner/feature_histogram.hpp

@@ -31,62 +31,20 @@ public:
   void Init(const Feature* feature, int feature_idx, const TreeConfig* tree_config) {
     feature_idx_ = feature_idx;
     tree_config_ = tree_config;
-    num_bins_ = feature->num_bin();
-    data_.resize(num_bins_);
+    feature_ = feature;
+    data_.resize(feature_->num_bin());
     if (feature->bin_type() == BinType::NumericalBin) {
-      find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdForNumerical, this, std::placeholders::_1);
+      find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdForNumerical, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
     } else {
-      find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdForCategorical, this, std::placeholders::_1);
+      find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdForCategorical, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
     }
   }
-  /*!
-  * \brief Construct a histogram
-  * \param num_data number of data in current leaf
-  * \param sum_gradients sum of gradients of current leaf
-  * \param sum_hessians sum of hessians of current leaf
-  * \param ordered_gradients Orederd gradients
-  * \param ordered_hessians Ordered hessians
-  * \param data_indices data indices of current leaf
-  */
-  void Construct(const Bin* bin_data, const data_size_t* data_indices, data_size_t num_data, double sum_gradients, double sum_hessians, const score_t* ordered_gradients, const score_t* ordered_hessians) {
-    std::memset(data_.data(), 0, sizeof(HistogramBinEntry) * num_bins_);
-    num_data_ = num_data;
-    sum_gradients_ = sum_gradients;
-    sum_hessians_ = sum_hessians + 2 * kEpsilon;
-    bin_data->ConstructHistogram(data_indices, num_data, ordered_gradients, ordered_hessians, data_.data());
-  }
-  /*!
-  * \brief Construct a histogram by ordered bin
-  * \param leaf current leaf
-  * \param num_data number of data in current leaf
-  * \param sum_gradients sum of gradients of current leaf
-  * \param sum_hessians sum of hessians of current leaf
-  * \param gradients
-  * \param hessian
-  */
-  void Construct(const OrderedBin* ordered_bin, int leaf, data_size_t num_data, double sum_gradients, double sum_hessians, const score_t* gradients, const score_t* hessians) {
-    std::memset(data_.data(), 0, sizeof(HistogramBinEntry) * num_bins_);
-    num_data_ = num_data;
-    sum_gradients_ = sum_gradients;
-    sum_hessians_ = sum_hessians + 2 * kEpsilon;
-    ordered_bin->ConstructHistogram(leaf, gradients, hessians, data_.data());
-  }
-  /*!
-  * \brief Set sumup information for current histogram
-  * \param num_data number of data in current leaf
-  * \param sum_gradients sum of gradients of current leaf
-  * \param sum_hessians sum of hessians of current leaf
-  */
-  void SetSumup(data_size_t num_data, double sum_gradients, double sum_hessians) {
-    num_data_ = num_data;
-    sum_gradients_ = sum_gradients;
-    sum_hessians_ = sum_hessians + 2 * kEpsilon;
-  }
+  HistogramBinEntry* GetData() {
+    std::memset(data_.data(), 0, feature_->num_bin() * sizeof(HistogramBinEntry));
+    return data_.data();
+  }
   /*!

@@ -94,53 +52,56 @@ public:
   * \param other The histogram that want to subtract
   */
   void Subtract(const FeatureHistogram& other) {
-    num_data_ -= other.num_data_;
-    sum_gradients_ -= other.sum_gradients_;
-    sum_hessians_ -= other.sum_hessians_;
-    for (unsigned int i = 0; i < num_bins_; ++i) {
+    for (int i = 0; i < feature_->num_bin(); ++i) {
       data_[i].cnt -= other.data_[i].cnt;
       data_[i].sum_gradients -= other.data_[i].sum_gradients;
       data_[i].sum_hessians -= other.data_[i].sum_hessians;
     }
   }
   /*!
   * \brief Find best threshold for this histogram
   * \param output The best split result
   */
-  void FindBestThreshold(SplitInfo* output) {
-    find_best_threshold_fun_(output);
+  void FindBestThreshold(double sum_gradient, double sum_hessian, data_size_t num_data, SplitInfo* output) {
+    find_best_threshold_fun_(sum_gradient, sum_hessian, num_data, output);
     if (output->gain > kMinScore) {
       is_splittable_ = true;
     } else {
       is_splittable_ = false;
     }
   }

-  void FindBestThresholdForNumerical(SplitInfo* output) {
+  void FindBestThresholdForNumerical(double sum_gradient, double sum_hessian, data_size_t num_data, SplitInfo* output) {
     double best_sum_left_gradient = NAN;
     double best_sum_left_hessian = NAN;
     double best_gain = kMinScore;
     data_size_t best_left_count = 0;
-    unsigned int best_threshold = static_cast<unsigned int>(num_bins_);
+    unsigned int best_threshold = static_cast<unsigned int>(feature_->num_bin());
     double sum_right_gradient = 0.0f;
     double sum_right_hessian = kEpsilon;
     data_size_t right_count = 0;
-    double gain_shift = GetLeafSplitGain(sum_gradients_, sum_hessians_);
+    double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian);
     double min_gain_shift = gain_shift + tree_config_->min_gain_to_split;
-    is_splittable_ = false;
+    bool is_splittable = false;
     // from right to left, and we don't need data in bin0
-    for (unsigned int t = num_bins_ - 1; t > 0; --t) {
+    for (int t = feature_->num_bin() - 1; t > 0; --t) {
       sum_right_gradient += data_[t].sum_gradients;
       sum_right_hessian += data_[t].sum_hessians;
       right_count += data_[t].cnt;
       // if data not enough, or sum hessian too small
       if (right_count < tree_config_->min_data_in_leaf || sum_right_hessian < tree_config_->min_sum_hessian_in_leaf) continue;
-      data_size_t left_count = num_data_ - right_count;
+      data_size_t left_count = num_data - right_count;
       // if data not enough
       if (left_count < tree_config_->min_data_in_leaf) break;
-      double sum_left_hessian = sum_hessians_ - sum_right_hessian;
+      double sum_left_hessian = sum_hessian - sum_right_hessian;
      // if sum hessian too small
      if (sum_left_hessian < tree_config_->min_sum_hessian_in_leaf) break;
-      double sum_left_gradient = sum_gradients_ - sum_right_gradient;
+      double sum_left_gradient = sum_gradient - sum_right_gradient;
       // current split gain
       double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian) + GetLeafSplitGain(sum_right_gradient, sum_right_hessian);

@@ -148,18 +109,18 @@ public:
       if (current_gain < min_gain_shift) continue;
       // mark to is splittable
-      is_splittable_ = true;
+      is_splittable = true;
       // better split point
       if (current_gain > best_gain) {
         best_left_count = left_count;
         best_sum_left_gradient = sum_left_gradient;
         best_sum_left_hessian = sum_left_hessian;
         // left is <= threshold, right is > threshold. so this is t-1
-        best_threshold = t - 1;
+        best_threshold = static_cast<unsigned int>(t - 1);
         best_gain = current_gain;
       }
     }
-    if (is_splittable_) {
+    if (is_splittable) {
       // update split information
       output->feature = feature_idx_;
       output->threshold = best_threshold;

@@ -167,11 +128,11 @@ public:
       output->left_count = best_left_count;
       output->left_sum_gradient = best_sum_left_gradient;
       output->left_sum_hessian = best_sum_left_hessian;
-      output->right_output = CalculateSplittedLeafOutput(sum_gradients_ - best_sum_left_gradient, sum_hessians_ - best_sum_left_hessian);
-      output->right_count = num_data_ - best_left_count;
-      output->right_sum_gradient = sum_gradients_ - best_sum_left_gradient;
-      output->right_sum_hessian = sum_hessians_ - best_sum_left_hessian;
+      output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, sum_hessian - best_sum_left_hessian);
+      output->right_count = num_data - best_left_count;
+      output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
+      output->right_sum_hessian = sum_hessian - best_sum_left_hessian;
       output->gain = best_gain - gain_shift;
     } else {
       output->feature = feature_idx_;

@@ -183,30 +144,30 @@ public:
   * \brief Find best threshold for this histogram
   * \param output The best split result
   */
-  void FindBestThresholdForCategorical(SplitInfo* output) {
+  void FindBestThresholdForCategorical(double sum_gradient, double sum_hessian, data_size_t num_data, SplitInfo* output) {
     double best_gain = kMinScore;
-    unsigned int best_threshold = static_cast<unsigned int>(num_bins_);
+    unsigned int best_threshold = static_cast<unsigned int>(feature_->num_bin());
-    double gain_shift = GetLeafSplitGain(sum_gradients_, sum_hessians_);
+    double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian);
     double min_gain_shift = gain_shift + tree_config_->min_gain_to_split;
-    is_splittable_ = false;
-    for (int t = num_bins_ - 1; t >= 0; --t) {
+    bool is_splittable = false;
+    for (int t = feature_->num_bin() - 1; t >= 0; --t) {
       double sum_current_gradient = data_[t].sum_gradients;
       double sum_current_hessian = data_[t].sum_hessians;
       data_size_t current_count = data_[t].cnt;
       // if data not enough, or sum hessian too small
       if (current_count < tree_config_->min_data_in_leaf || sum_current_hessian < tree_config_->min_sum_hessian_in_leaf) continue;
-      data_size_t other_count = num_data_ - current_count;
+      data_size_t other_count = num_data - current_count;
       // if data not enough
       if (other_count < tree_config_->min_data_in_leaf) continue;
-      double sum_other_hessian = sum_hessians_ - sum_current_hessian;
+      double sum_other_hessian = sum_hessian - sum_current_hessian;
       // if sum hessian too small
       if (sum_other_hessian < tree_config_->min_sum_hessian_in_leaf) continue;
-      double sum_other_gradient = sum_gradients_ - sum_current_gradient;
+      double sum_other_gradient = sum_gradient - sum_current_gradient;
       // current split gain
       double current_gain = GetLeafSplitGain(sum_other_gradient, sum_other_hessian) + GetLeafSplitGain(sum_current_gradient, sum_current_hessian);

@@ -214,7 +175,7 @@ public:
       if (current_gain < min_gain_shift) continue;
       // mark to is splittable
-      is_splittable_ = true;
+      is_splittable = true;
       // better split point
       if (current_gain > best_gain) {
         best_threshold = static_cast<unsigned int>(t);

@@ -222,7 +183,7 @@ public:
       }
     }
     // update split information
-    if (is_splittable_) {
+    if (is_splittable) {
       output->feature = feature_idx_;
       output->threshold = best_threshold;
       output->left_output = CalculateSplittedLeafOutput(data_[best_threshold].sum_gradients,

@@ -231,11 +192,11 @@ public:
       output->left_sum_gradient = data_[best_threshold].sum_gradients;
       output->left_sum_hessian = data_[best_threshold].sum_hessians;
-      output->right_output = CalculateSplittedLeafOutput(sum_gradients_ - data_[best_threshold].sum_gradients, sum_hessians_ - data_[best_threshold].sum_hessians);
-      output->right_count = num_data_ - data_[best_threshold].cnt;
-      output->right_sum_gradient = sum_gradients_ - data_[best_threshold].sum_gradients;
-      output->right_sum_hessian = sum_hessians_ - data_[best_threshold].sum_hessians;
+      output->right_output = CalculateSplittedLeafOutput(sum_gradient - data_[best_threshold].sum_gradients, sum_hessian - data_[best_threshold].sum_hessians);
+      output->right_count = num_data - data_[best_threshold].cnt;
+      output->right_sum_gradient = sum_gradient - data_[best_threshold].sum_gradients;
+      output->right_sum_hessian = sum_hessian - data_[best_threshold].sum_hessians;
       output->gain = best_gain - gain_shift;
     } else {

@@ -243,12 +204,11 @@ public:
       output->gain = kMinScore;
     }
   }
   /*!
   * \brief Binary size of this histogram
   */
   int SizeOfHistgram() const {
-    return num_bins_ * sizeof(HistogramBinEntry);
+    return feature_->num_bin() * sizeof(HistogramBinEntry);
   }
   /*!

@@ -262,7 +222,7 @@ public:
   * \brief Restore histogram from memory
   */
   void FromMemory(char* memory_data) {
-    std::memcpy(data_.data(), memory_data, num_bins_ * sizeof(HistogramBinEntry));
+    std::memcpy(data_.data(), memory_data, feature_->num_bin() * sizeof(HistogramBinEntry));
   }
   /*!

@@ -312,22 +272,15 @@ private:
   }
   int feature_idx_;
+  const Feature* feature_;
   /*! \brief pointer of tree config */
   const TreeConfig* tree_config_;
-  /*! \brief number of bin of histogram */
-  unsigned int num_bins_;
   /*! \brief sum of gradient of each bin */
   std::vector<HistogramBinEntry> data_;
-  /*! \brief number of all data */
-  data_size_t num_data_;
-  /*! \brief sum of gradient of current leaf */
-  double sum_gradients_;
-  /*! \brief sum of hessians of current leaf */
-  double sum_hessians_;
   /*! \brief False if this histogram cannot split */
   bool is_splittable_ = true;
   /*! \brief function that used to find best threshold */
-  std::function<void(SplitInfo*)> find_best_threshold_fun_;
+  std::function<void(double, double, data_size_t, SplitInfo*)> find_best_threshold_fun_;
 };
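This file carries the design shift: FeatureHistogram no longer caches per-leaf state (num_bins_, num_data_, sum_gradients_, and sum_hessians_ are gone, along with both Construct overloads and SetSumup). GetData() zeroes the bin buffer and hands it out so that Bin::ConstructHistogram writes into histogram memory directly, with no intermediate copy, and FindBestThreshold now receives the leaf's gradient sum, hessian sum, and row count as arguments. The split search itself still reduces to one scalar per side; a hedged restatement of the gain arithmetic under the usual L2-regularized objective (the exact regularization terms inside GetLeafSplitGain may differ):

    // For a leaf with gradient sum G and hessian sum H, the optimal output is
    // w* = -G / (H + lambda); substituting w* back into the quadratic
    // objective gives the attainable gain G^2 / (H + lambda).
    double LeafSplitGain(double sum_gradient, double sum_hessian, double lambda_l2) {
      return (sum_gradient * sum_gradient) / (sum_hessian + lambda_l2);
    }

    // A candidate threshold survives only if
    //   LeafSplitGain(G_left, H_left) + LeafSplitGain(G_right, H_right)
    // beats gain_shift = LeafSplitGain(G_parent, H_parent) by at least
    // min_gain_to_split -- exactly the min_gain_shift test in the loops above.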
src/treelearner/serial_tree_learner.cpp

@@ -390,26 +390,25 @@ void SerialTreeLearner::FindBestThresholds() {
     // construct histograms for smaller leaf
     if (ordered_bins_[feature_index] == nullptr) {
       // if not use ordered bin
-      smaller_leaf_histogram_array_[feature_index].Construct(train_data_->FeatureAt(feature_index)->bin_data(),
+      train_data_->FeatureAt(feature_index)->bin_data()->ConstructHistogram(
         smaller_leaf_splits_->data_indices(),
         smaller_leaf_splits_->num_data_in_leaf(),
-        smaller_leaf_splits_->sum_gradients(),
-        smaller_leaf_splits_->sum_hessians(),
         ptr_to_ordered_gradients_smaller_leaf_,
-        ptr_to_ordered_hessians_smaller_leaf_);
+        ptr_to_ordered_hessians_smaller_leaf_,
+        smaller_leaf_histogram_array_[feature_index].GetData());
     } else {
       // used ordered bin
-      smaller_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
-        smaller_leaf_splits_->LeafIndex(),
-        smaller_leaf_splits_->num_data_in_leaf(),
-        smaller_leaf_splits_->sum_gradients(),
-        smaller_leaf_splits_->sum_hessians(),
+      ordered_bins_[feature_index]->ConstructHistogram(smaller_leaf_splits_->LeafIndex(),
         gradients_,
-        hessians_);
+        hessians_,
+        smaller_leaf_histogram_array_[feature_index].GetData());
     }
     // find best threshold for smaller child
-    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(&smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);
+    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
+      smaller_leaf_splits_->sum_gradients(),
+      smaller_leaf_splits_->sum_hessians(),
+      smaller_leaf_splits_->num_data_in_leaf(),
+      &smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);
     // only has root leaf
     if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) continue;

@@ -421,28 +420,27 @@ void SerialTreeLearner::FindBestThresholds() {
     } else {
       if (ordered_bins_[feature_index] == nullptr) {
         // if not use ordered bin
-        larger_leaf_histogram_array_[feature_index].Construct(train_data_->FeatureAt(feature_index)->bin_data(),
+        train_data_->FeatureAt(feature_index)->bin_data()->ConstructHistogram(
           larger_leaf_splits_->data_indices(),
           larger_leaf_splits_->num_data_in_leaf(),
-          larger_leaf_splits_->sum_gradients(),
-          larger_leaf_splits_->sum_hessians(),
          ptr_to_ordered_gradients_larger_leaf_,
-          ptr_to_ordered_hessians_larger_leaf_);
+          ptr_to_ordered_hessians_larger_leaf_,
+          larger_leaf_histogram_array_[feature_index].GetData());
       } else {
         // used ordered bin
-        larger_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
-          larger_leaf_splits_->LeafIndex(),
-          larger_leaf_splits_->num_data_in_leaf(),
-          larger_leaf_splits_->sum_gradients(),
-          larger_leaf_splits_->sum_hessians(),
+        ordered_bins_[feature_index]->ConstructHistogram(larger_leaf_splits_->LeafIndex(),
           gradients_,
-          hessians_);
+          hessians_,
+          larger_leaf_histogram_array_[feature_index].GetData());
       }
     }
     // find best threshold for larger child
-    larger_leaf_histogram_array_[feature_index].FindBestThreshold(&larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
+    larger_leaf_histogram_array_[feature_index].FindBestThreshold(
+      larger_leaf_splits_->sum_gradients(),
+      larger_leaf_splits_->sum_hessians(),
+      larger_leaf_splits_->num_data_in_leaf(),
+      &larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
   }
 }
src/treelearner/voting_parallel_tree_learner.cpp

@@ -264,29 +264,30 @@ void VotingParallelTreeLearner::FindBestThresholds() {
   Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer);
   // find best split from local aggregated histograms
   #pragma omp parallel for schedule(guided)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
     if (smaller_is_feature_aggregated_[feature_index]) {
-      smaller_leaf_histogram_array_global_[feature_index].SetSumup(
-        GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
-        smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians());
       // restore from buffer
       smaller_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
       // find best threshold
-      smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
-        &smaller_leaf_splits_global_->BestSplitPerFeature()[feature_index]);
+      smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
+        smaller_leaf_splits_global_->sum_gradients(),
+        smaller_leaf_splits_global_->sum_hessians(),
+        GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
+        &smaller_leaf_splits_global_->BestSplitPerFeature()[feature_index]);
     }
     if (larger_is_feature_aggregated_[feature_index]) {
-      larger_leaf_histogram_array_global_[feature_index].SetSumup(
-        GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
-        larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians());
       // restore from buffer
       larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
       // find best threshold
-      larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(
-        &larger_leaf_splits_global_->BestSplitPerFeature()[feature_index]);
+      larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(
+        larger_leaf_splits_global_->sum_gradients(),
+        larger_leaf_splits_global_->sum_hessians(),
+        GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
+        &larger_leaf_splits_global_->BestSplitPerFeature()[feature_index]);
     }
   }