Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
c8fbd42b
"...tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b0774151cbd860b02ed21d63d01cf5acfc1c4702"
Commit
c8fbd42b
authored
Jan 22, 2017
by
Guolin Ke
Browse files
use subset to speed up bagging
parent
873528c1
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
143 additions
and
44 deletions
+143
-44
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+1
-1
include/LightGBM/tree_learner.h
include/LightGBM/tree_learner.h
+2
-0
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+41
-4
src/boosting/gbdt.h
src/boosting/gbdt.h
+2
-0
src/c_api.cpp
src/c_api.cpp
+2
-1
src/io/dataset.cpp
src/io/dataset.cpp
+6
-2
src/treelearner/data_parallel_tree_learner.cpp
src/treelearner/data_parallel_tree_learner.cpp
+30
-28
src/treelearner/data_partition.hpp
src/treelearner/data_partition.hpp
+6
-1
src/treelearner/feature_histogram.hpp
src/treelearner/feature_histogram.hpp
+2
-5
src/treelearner/leaf_splits.hpp
src/treelearner/leaf_splits.hpp
+4
-0
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+45
-2
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+2
-0
No files found.
include/LightGBM/dataset.h
View file @
c8fbd42b
...
@@ -328,7 +328,7 @@ public:
...
@@ -328,7 +328,7 @@ public:
return
used_feature_map_
[
col_idx
];
return
used_feature_map_
[
col_idx
];
}
}
Dataset
*
Subset
(
const
data_size_t
*
used_indices
,
data_size_t
num_used_indices
,
bool
is_enable_sparse
)
const
;
Dataset
*
Subset
(
const
data_size_t
*
used_indices
,
data_size_t
num_used_indices
,
bool
is_enable_sparse
,
bool
need_meta_data
)
const
;
LIGHTGBM_EXPORT
void
FinishLoad
();
LIGHTGBM_EXPORT
void
FinishLoad
();
...
...
include/LightGBM/tree_learner.h
View file @
c8fbd42b
...
@@ -27,6 +27,8 @@ public:
...
@@ -27,6 +27,8 @@ public:
*/
*/
virtual
void
Init
(
const
Dataset
*
train_data
)
=
0
;
virtual
void
Init
(
const
Dataset
*
train_data
)
=
0
;
virtual
void
ResetTrainingData
(
const
Dataset
*
train_data
)
=
0
;
/*!
/*!
* \brief Reset tree configs
* \brief Reset tree configs
* \param tree_config config of tree
* \param tree_config config of tree
...
...
src/boosting/gbdt.cpp
View file @
c8fbd42b
...
@@ -134,10 +134,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -134,10 +134,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
right_cnts_buf_
.
resize
(
num_threads_
);
right_cnts_buf_
.
resize
(
num_threads_
);
left_write_pos_buf_
.
resize
(
num_threads_
);
left_write_pos_buf_
.
resize
(
num_threads_
);
right_write_pos_buf_
.
resize
(
num_threads_
);
right_write_pos_buf_
.
resize
(
num_threads_
);
double
average_bag_rate
=
new_config
->
bagging_fraction
/
new_config
->
bagging_freq
;
is_use_subset_
=
false
;
if
(
average_bag_rate
<
0.3
)
{
is_use_subset_
=
true
;
Log
::
Debug
(
"use subset for bagging"
);
}
}
else
{
}
else
{
bag_data_cnt_
=
num_data_
;
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
bag_data_indices_
.
clear
();
tmp_indices_
.
clear
();
tmp_indices_
.
clear
();
is_use_subset_
=
false
;
}
}
}
}
train_data_
=
train_data
;
train_data_
=
train_data
;
...
@@ -196,6 +203,7 @@ data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t*
...
@@ -196,6 +203,7 @@ data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t*
buffer
[
bag_data_cnt
+
cur_right_cnt
++
]
=
start
+
i
;
buffer
[
bag_data_cnt
+
cur_right_cnt
++
]
=
start
+
i
;
}
}
}
}
CHECK
(
buffer
[
bag_data_cnt
-
1
]
>
buffer
[
bag_data_cnt
]);
CHECK
(
cur_left_cnt
==
bag_data_cnt
);
CHECK
(
cur_left_cnt
==
bag_data_cnt
);
return
cur_left_cnt
;
return
cur_left_cnt
;
}
}
...
@@ -240,15 +248,24 @@ void GBDT::Bagging(int iter) {
...
@@ -240,15 +248,24 @@ void GBDT::Bagging(int iter) {
tmp_indices_
.
data
()
+
offsets_buf_
[
i
]
+
left_cnts_buf_
[
i
],
right_cnts_buf_
[
i
]
*
sizeof
(
data_size_t
));
tmp_indices_
.
data
()
+
offsets_buf_
[
i
]
+
left_cnts_buf_
[
i
],
right_cnts_buf_
[
i
]
*
sizeof
(
data_size_t
));
}
}
}
}
bag_data_cnt_
=
left_cnt
;
CHECK
(
bag_data_indices_
[
bag_data_cnt_
-
1
]
>
bag_data_indices_
[
bag_data_cnt_
]);
Log
::
Debug
(
"Re-bagging, using %d data to train"
,
bag_data_cnt_
);
Log
::
Debug
(
"Re-bagging, using %d data to train"
,
bag_data_cnt_
);
// set bagging data to tree learner
// set bagging data to tree learner
if
(
!
is_use_subset_
)
{
tree_learner_
->
SetBaggingData
(
bag_data_indices_
.
data
(),
bag_data_cnt_
);
tree_learner_
->
SetBaggingData
(
bag_data_indices_
.
data
(),
bag_data_cnt_
);
}
else
{
// get subset
tmp_subset_
.
reset
(
train_data_
->
Subset
(
bag_data_indices_
.
data
(),
bag_data_cnt_
,
false
,
false
));
tmp_subset_
->
FinishLoad
();
tree_learner_
->
ResetTrainingData
(
tmp_subset_
.
get
());
}
}
}
}
}
void
GBDT
::
UpdateScoreOutOfBag
(
const
Tree
*
tree
,
const
int
curr_class
)
{
void
GBDT
::
UpdateScoreOutOfBag
(
const
Tree
*
tree
,
const
int
curr_class
)
{
// we need to predict out-of-bag scores of data for boosting
// we need to predict out-of-bag scores of data for boosting
if
(
num_data_
-
bag_data_cnt_
>
0
)
{
if
(
num_data_
-
bag_data_cnt_
>
0
&&
!
is_use_subset_
)
{
train_score_updater_
->
AddScore
(
tree
,
bag_data_indices_
.
data
()
+
bag_data_cnt_
,
num_data_
-
bag_data_cnt_
,
curr_class
);
train_score_updater_
->
AddScore
(
tree
,
bag_data_indices_
.
data
()
+
bag_data_cnt_
,
num_data_
-
bag_data_cnt_
,
curr_class
);
}
}
}
}
...
@@ -262,8 +279,24 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -262,8 +279,24 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
}
}
// bagging logic
// bagging logic
Bagging
(
iter_
);
Bagging
(
iter_
);
if
(
is_use_subset_
&&
bag_data_cnt_
<
num_data_
)
{
if
(
gradients_
.
empty
())
{
size_t
total_size
=
static_cast
<
size_t
>
(
num_data_
)
*
num_class_
;
gradients_
.
resize
(
total_size
);
hessians_
.
resize
(
total_size
);
}
// get sub gradients
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
bias
=
curr_class
*
num_data_
;
for
(
int
i
=
0
;
i
<
bag_data_cnt_
;
++
i
)
{
gradients_
[
bias
+
i
]
=
gradient
[
bias
+
bag_data_indices_
[
i
]];
hessians_
[
bias
+
i
]
=
hessian
[
bias
+
bag_data_indices_
[
i
]];
}
}
gradient
=
gradients_
.
data
();
hessian
=
hessians_
.
data
();
}
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
// train a new tree
// train a new tree
std
::
unique_ptr
<
Tree
>
new_tree
(
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
std
::
unique_ptr
<
Tree
>
new_tree
(
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
// if cannot learn a new tree, then stop
// if cannot learn a new tree, then stop
...
@@ -328,7 +361,11 @@ bool GBDT::EvalAndCheckEarlyStopping() {
...
@@ -328,7 +361,11 @@ bool GBDT::EvalAndCheckEarlyStopping() {
void
GBDT
::
UpdateScore
(
const
Tree
*
tree
,
const
int
curr_class
)
{
void
GBDT
::
UpdateScore
(
const
Tree
*
tree
,
const
int
curr_class
)
{
// update training score
// update training score
if
(
!
is_use_subset_
)
{
train_score_updater_
->
AddScore
(
tree_learner_
.
get
(),
curr_class
);
train_score_updater_
->
AddScore
(
tree_learner_
.
get
(),
curr_class
);
}
else
{
train_score_updater_
->
AddScore
(
tree
,
curr_class
);
}
// update validation score
// update validation score
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
tree
,
curr_class
);
score_updater
->
AddScore
(
tree
,
curr_class
);
...
...
src/boosting/gbdt.h
View file @
c8fbd42b
...
@@ -339,6 +339,8 @@ protected:
...
@@ -339,6 +339,8 @@ protected:
std
::
vector
<
data_size_t
>
left_write_pos_buf_
;
std
::
vector
<
data_size_t
>
left_write_pos_buf_
;
/*! \brief Buffer for multi-threading bagging */
/*! \brief Buffer for multi-threading bagging */
std
::
vector
<
data_size_t
>
right_write_pos_buf_
;
std
::
vector
<
data_size_t
>
right_write_pos_buf_
;
std
::
unique_ptr
<
Dataset
>
tmp_subset_
;
bool
is_use_subset_
;
};
};
}
// namespace LightGBM
}
// namespace LightGBM
...
...
src/c_api.cpp
View file @
c8fbd42b
...
@@ -488,7 +488,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
...
@@ -488,7 +488,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
auto
ret
=
std
::
unique_ptr
<
Dataset
>
(
auto
ret
=
std
::
unique_ptr
<
Dataset
>
(
full_dataset
->
Subset
(
used_row_indices
,
full_dataset
->
Subset
(
used_row_indices
,
num_used_row_indices
,
num_used_row_indices
,
io_config
.
is_enable_sparse
));
io_config
.
is_enable_sparse
,
true
));
ret
->
FinishLoad
();
ret
->
FinishLoad
();
*
out
=
ret
.
release
();
*
out
=
ret
.
release
();
API_END
();
API_END
();
...
...
src/io/dataset.cpp
View file @
c8fbd42b
...
@@ -22,6 +22,7 @@ Dataset::Dataset() {
...
@@ -22,6 +22,7 @@ Dataset::Dataset() {
}
}
Dataset
::
Dataset
(
data_size_t
num_data
)
{
Dataset
::
Dataset
(
data_size_t
num_data
)
{
data_filename_
=
"noname"
;
num_data_
=
num_data
;
num_data_
=
num_data
;
metadata_
.
Init
(
num_data_
,
-
1
,
-
1
);
metadata_
.
Init
(
num_data_
,
-
1
,
-
1
);
}
}
...
@@ -56,7 +57,8 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars
...
@@ -56,7 +57,8 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars
label_idx_
=
dataset
->
label_idx_
;
label_idx_
=
dataset
->
label_idx_
;
}
}
Dataset
*
Dataset
::
Subset
(
const
data_size_t
*
used_indices
,
data_size_t
num_used_indices
,
bool
is_enable_sparse
)
const
{
Dataset
*
Dataset
::
Subset
(
const
data_size_t
*
used_indices
,
data_size_t
num_used_indices
,
bool
is_enable_sparse
,
bool
need_meta_data
)
const
{
auto
ret
=
std
::
unique_ptr
<
Dataset
>
(
new
Dataset
(
num_used_indices
));
auto
ret
=
std
::
unique_ptr
<
Dataset
>
(
new
Dataset
(
num_used_indices
));
ret
->
CopyFeatureMapperFrom
(
this
,
is_enable_sparse
);
ret
->
CopyFeatureMapperFrom
(
this
,
is_enable_sparse
);
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
...
@@ -66,7 +68,9 @@ Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_i
...
@@ -66,7 +68,9 @@ Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_i
ret
->
features_
[
fidx
]
->
PushBin
(
0
,
i
,
iterator
->
Get
(
used_indices
[
i
]));
ret
->
features_
[
fidx
]
->
PushBin
(
0
,
i
,
iterator
->
Get
(
used_indices
[
i
]));
}
}
}
}
if
(
need_meta_data
)
{
ret
->
metadata_
.
Init
(
metadata_
,
used_indices
,
num_used_indices
);
ret
->
metadata_
.
Init
(
metadata_
,
used_indices
,
num_used_indices
);
}
return
ret
.
release
();
return
ret
.
release
();
}
}
...
...
src/treelearner/data_parallel_tree_learner.cpp
View file @
c8fbd42b
...
@@ -126,12 +126,14 @@ void DataParallelTreeLearner::BeforeTrain() {
...
@@ -126,12 +126,14 @@ void DataParallelTreeLearner::BeforeTrain() {
void
DataParallelTreeLearner
::
FindBestThresholds
()
{
void
DataParallelTreeLearner
::
FindBestThresholds
()
{
// construct local histograms
// construct local histograms
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
feature_index
=
0
;
feature_index
<
num_features_
;
++
feature_index
)
{
for
(
int
feature_index
=
0
;
feature_index
<
num_features_
;
++
feature_index
)
{
if
((
!
is_feature_used_
.
empty
()
&&
is_feature_used_
[
feature_index
]
==
false
))
continue
;
if
((
!
is_feature_used_
.
empty
()
&&
is_feature_used_
[
feature_index
]
==
false
))
continue
;
// construct histograms for smaller leaf
// construct histograms for smaller leaf
if
(
ordered_bins_
[
feature_index
]
==
nullptr
)
{
if
(
ordered_bins_
[
feature_index
]
==
nullptr
)
{
smaller_leaf_histogram_array_
[
feature_index
].
Construct
(
smaller_leaf_splits_
->
data_indices
(),
smaller_leaf_histogram_array_
[
feature_index
].
Construct
(
train_data_
->
FeatureAt
(
feature_index
)
->
bin_data
(),
smaller_leaf_splits_
->
data_indices
(),
smaller_leaf_splits_
->
num_data_in_leaf
(),
smaller_leaf_splits_
->
num_data_in_leaf
(),
smaller_leaf_splits_
->
sum_gradients
(),
smaller_leaf_splits_
->
sum_gradients
(),
smaller_leaf_splits_
->
sum_hessians
(),
smaller_leaf_splits_
->
sum_hessians
(),
...
@@ -155,7 +157,7 @@ void DataParallelTreeLearner::FindBestThresholds() {
...
@@ -155,7 +157,7 @@ void DataParallelTreeLearner::FindBestThresholds() {
// Reduce scatter for histogram
// Reduce scatter for histogram
Network
::
ReduceScatter
(
input_buffer_
.
data
(),
reduce_scatter_size_
,
block_start_
.
data
(),
Network
::
ReduceScatter
(
input_buffer_
.
data
(),
reduce_scatter_size_
,
block_start_
.
data
(),
block_len_
.
data
(),
output_buffer_
.
data
(),
&
HistogramBinEntry
::
SumReducer
);
block_len_
.
data
(),
output_buffer_
.
data
(),
&
HistogramBinEntry
::
SumReducer
);
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
feature_index
=
0
;
feature_index
<
num_features_
;
++
feature_index
)
{
for
(
int
feature_index
=
0
;
feature_index
<
num_features_
;
++
feature_index
)
{
if
(
!
is_feature_aggregated_
[
feature_index
])
continue
;
if
(
!
is_feature_aggregated_
[
feature_index
])
continue
;
// copy global sumup info
// copy global sumup info
...
...
src/treelearner/data_partition.hpp
View file @
c8fbd42b
...
@@ -41,7 +41,12 @@ public:
...
@@ -41,7 +41,12 @@ public:
leaf_begin_
.
resize
(
num_leaves_
);
leaf_begin_
.
resize
(
num_leaves_
);
leaf_count_
.
resize
(
num_leaves_
);
leaf_count_
.
resize
(
num_leaves_
);
}
}
void
ResetNumData
(
int
num_data
)
{
num_data_
=
num_data
;
indices_
.
resize
(
num_data_
);
temp_left_indices_
.
resize
(
num_data_
);
temp_right_indices_
.
resize
(
num_data_
);
}
~
DataPartition
()
{
~
DataPartition
()
{
}
}
...
...
src/treelearner/feature_histogram.hpp
View file @
c8fbd42b
...
@@ -31,7 +31,6 @@ public:
...
@@ -31,7 +31,6 @@ public:
void
Init
(
const
Feature
*
feature
,
int
feature_idx
,
const
TreeConfig
*
tree_config
)
{
void
Init
(
const
Feature
*
feature
,
int
feature_idx
,
const
TreeConfig
*
tree_config
)
{
feature_idx_
=
feature_idx
;
feature_idx_
=
feature_idx
;
tree_config_
=
tree_config
;
tree_config_
=
tree_config
;
bin_data_
=
feature
->
bin_data
();
num_bins_
=
feature
->
num_bin
();
num_bins_
=
feature
->
num_bin
();
data_
.
resize
(
num_bins_
);
data_
.
resize
(
num_bins_
);
if
(
feature
->
bin_type
()
==
BinType
::
NumericalBin
)
{
if
(
feature
->
bin_type
()
==
BinType
::
NumericalBin
)
{
...
@@ -51,13 +50,13 @@ public:
...
@@ -51,13 +50,13 @@ public:
* \param ordered_hessians Ordered hessians
* \param ordered_hessians Ordered hessians
* \param data_indices data indices of current leaf
* \param data_indices data indices of current leaf
*/
*/
void
Construct
(
const
data_size_t
*
data_indices
,
data_size_t
num_data
,
double
sum_gradients
,
void
Construct
(
const
Bin
*
bin_data
,
const
data_size_t
*
data_indices
,
data_size_t
num_data
,
double
sum_gradients
,
double
sum_hessians
,
const
score_t
*
ordered_gradients
,
const
score_t
*
ordered_hessians
)
{
double
sum_hessians
,
const
score_t
*
ordered_gradients
,
const
score_t
*
ordered_hessians
)
{
std
::
memset
(
data_
.
data
(),
0
,
sizeof
(
HistogramBinEntry
)
*
num_bins_
);
std
::
memset
(
data_
.
data
(),
0
,
sizeof
(
HistogramBinEntry
)
*
num_bins_
);
num_data_
=
num_data
;
num_data_
=
num_data
;
sum_gradients_
=
sum_gradients
;
sum_gradients_
=
sum_gradients
;
sum_hessians_
=
sum_hessians
+
2
*
kEpsilon
;
sum_hessians_
=
sum_hessians
+
2
*
kEpsilon
;
bin_data
_
->
ConstructHistogram
(
data_indices
,
num_data
,
ordered_gradients
,
ordered_hessians
,
data_
.
data
());
bin_data
->
ConstructHistogram
(
data_indices
,
num_data
,
ordered_gradients
,
ordered_hessians
,
data_
.
data
());
}
}
/*!
/*!
...
@@ -315,8 +314,6 @@ private:
...
@@ -315,8 +314,6 @@ private:
int
feature_idx_
;
int
feature_idx_
;
/*! \brief pointer of tree config */
/*! \brief pointer of tree config */
const
TreeConfig
*
tree_config_
;
const
TreeConfig
*
tree_config_
;
/*! \brief the bin data of current feature */
const
Bin
*
bin_data_
;
/*! \brief number of bin of histogram */
/*! \brief number of bin of histogram */
unsigned
int
num_bins_
;
unsigned
int
num_bins_
;
/*! \brief sum of gradient of each bin */
/*! \brief sum of gradient of each bin */
...
...
src/treelearner/leaf_splits.hpp
View file @
c8fbd42b
...
@@ -22,6 +22,10 @@ public:
...
@@ -22,6 +22,10 @@ public:
best_split_per_feature_
[
i
].
feature
=
i
;
best_split_per_feature_
[
i
].
feature
=
i
;
}
}
}
}
void
ResetNumData
(
data_size_t
num_data
)
{
num_data_
=
num_data
;
num_data_in_leaf_
=
num_data
;
}
~
LeafSplits
()
{
~
LeafSplits
()
{
}
}
...
...
src/treelearner/serial_tree_learner.cpp
View file @
c8fbd42b
...
@@ -83,6 +83,45 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
...
@@ -83,6 +83,45 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
Log
::
Info
(
"Number of data: %d, number of features: %d"
,
num_data_
,
num_features_
);
Log
::
Info
(
"Number of data: %d, number of features: %d"
,
num_data_
,
num_features_
);
}
}
void
SerialTreeLearner
::
ResetTrainingData
(
const
Dataset
*
train_data
)
{
train_data_
=
train_data
;
num_data_
=
train_data_
->
num_data
();
num_features_
=
train_data_
->
num_features
();
// initialize ordered_bins_ with nullptr
ordered_bins_
.
resize
(
num_features_
);
// get ordered bin
#pragma omp parallel for schedule(guided)
for
(
int
i
=
0
;
i
<
num_features_
;
++
i
)
{
ordered_bins_
[
i
].
reset
(
train_data_
->
FeatureAt
(
i
)
->
bin_data
()
->
CreateOrderedBin
());
}
has_ordered_bin_
=
false
;
// check existing for ordered bin
for
(
int
i
=
0
;
i
<
num_features_
;
++
i
)
{
if
(
ordered_bins_
[
i
]
!=
nullptr
)
{
has_ordered_bin_
=
true
;
break
;
}
}
// initialize splits for leaf
smaller_leaf_splits_
->
ResetNumData
(
num_data_
);
larger_leaf_splits_
->
ResetNumData
(
num_data_
);
// initialize data partition
data_partition_
->
ResetNumData
(
num_data_
);
is_feature_used_
.
resize
(
num_features_
);
// initialize ordered gradients and hessians
ordered_gradients_
.
resize
(
num_data_
);
ordered_hessians_
.
resize
(
num_data_
);
// if has ordered bin, need to allocate a buffer to fast split
if
(
has_ordered_bin_
)
{
is_data_in_leaf_
.
resize
(
num_data_
);
}
}
void
SerialTreeLearner
::
ResetConfig
(
const
TreeConfig
*
tree_config
)
{
void
SerialTreeLearner
::
ResetConfig
(
const
TreeConfig
*
tree_config
)
{
if
(
tree_config_
->
num_leaves
!=
tree_config
->
num_leaves
)
{
if
(
tree_config_
->
num_leaves
!=
tree_config
->
num_leaves
)
{
...
@@ -351,7 +390,9 @@ void SerialTreeLearner::FindBestThresholds() {
...
@@ -351,7 +390,9 @@ void SerialTreeLearner::FindBestThresholds() {
// construct histograms for smaller leaf
// construct histograms for smaller leaf
if
(
ordered_bins_
[
feature_index
]
==
nullptr
)
{
if
(
ordered_bins_
[
feature_index
]
==
nullptr
)
{
// if not use ordered bin
// if not use ordered bin
smaller_leaf_histogram_array_
[
feature_index
].
Construct
(
smaller_leaf_splits_
->
data_indices
(),
smaller_leaf_histogram_array_
[
feature_index
].
Construct
(
train_data_
->
FeatureAt
(
feature_index
)
->
bin_data
(),
smaller_leaf_splits_
->
data_indices
(),
smaller_leaf_splits_
->
num_data_in_leaf
(),
smaller_leaf_splits_
->
num_data_in_leaf
(),
smaller_leaf_splits_
->
sum_gradients
(),
smaller_leaf_splits_
->
sum_gradients
(),
smaller_leaf_splits_
->
sum_hessians
(),
smaller_leaf_splits_
->
sum_hessians
(),
...
@@ -380,7 +421,9 @@ void SerialTreeLearner::FindBestThresholds() {
...
@@ -380,7 +421,9 @@ void SerialTreeLearner::FindBestThresholds() {
}
else
{
}
else
{
if
(
ordered_bins_
[
feature_index
]
==
nullptr
)
{
if
(
ordered_bins_
[
feature_index
]
==
nullptr
)
{
// if not use ordered bin
// if not use ordered bin
larger_leaf_histogram_array_
[
feature_index
].
Construct
(
larger_leaf_splits_
->
data_indices
(),
larger_leaf_histogram_array_
[
feature_index
].
Construct
(
train_data_
->
FeatureAt
(
feature_index
)
->
bin_data
(),
larger_leaf_splits_
->
data_indices
(),
larger_leaf_splits_
->
num_data_in_leaf
(),
larger_leaf_splits_
->
num_data_in_leaf
(),
larger_leaf_splits_
->
sum_gradients
(),
larger_leaf_splits_
->
sum_gradients
(),
larger_leaf_splits_
->
sum_hessians
(),
larger_leaf_splits_
->
sum_hessians
(),
...
...
src/treelearner/serial_tree_learner.h
View file @
c8fbd42b
...
@@ -32,6 +32,8 @@ public:
...
@@ -32,6 +32,8 @@ public:
void
Init
(
const
Dataset
*
train_data
)
override
;
void
Init
(
const
Dataset
*
train_data
)
override
;
void
ResetTrainingData
(
const
Dataset
*
train_data
)
override
;
void
ResetConfig
(
const
TreeConfig
*
tree_config
)
override
;
void
ResetConfig
(
const
TreeConfig
*
tree_config
)
override
;
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
)
override
;
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
)
override
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment