Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5d12a8db
"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "6b56a90cd1324a6dbac2afa0a352c9355b0dc3cf"
Commit
5d12a8db
authored
Jan 10, 2017
by
Guolin Ke
Browse files
speed up bagging by multi-threading
parent
4306b22c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
54 deletions
+100
-54
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+75
-48
src/boosting/gbdt.h
src/boosting/gbdt.h
+25
-6
No files found.
src/boosting/gbdt.cpp
View file @
5d12a8db
#include "gbdt.h"
#include "gbdt.h"
#include <omp.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/feature.h>
#include <LightGBM/feature.h>
...
@@ -27,6 +29,11 @@ GBDT::GBDT()
...
@@ -27,6 +29,11 @@ GBDT::GBDT()
num_iteration_for_pred_
(
0
),
num_iteration_for_pred_
(
0
),
shrinkage_rate_
(
0.1
f
),
shrinkage_rate_
(
0.1
f
),
num_init_iteration_
(
0
)
{
num_init_iteration_
(
0
)
{
#pragma omp parallel
#pragma omp master
{
num_threads_
=
omp_get_num_threads
();
}
}
}
GBDT
::~
GBDT
()
{
GBDT
::~
GBDT
()
{
...
@@ -39,7 +46,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
...
@@ -39,7 +46,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
num_iteration_for_pred_
=
0
;
num_iteration_for_pred_
=
0
;
max_feature_idx_
=
0
;
max_feature_idx_
=
0
;
num_class_
=
config
->
num_class
;
num_class_
=
config
->
num_class
;
random_
=
Random
(
config
->
bagging_seed
);
for
(
int
i
=
0
;
i
<
num_threads_
;
++
i
)
{
random_
.
emplace_back
(
config
->
bagging_seed
+
i
);
}
train_data_
=
nullptr
;
train_data_
=
nullptr
;
gbdt_config_
=
nullptr
;
gbdt_config_
=
nullptr
;
tree_learner_
=
nullptr
;
tree_learner_
=
nullptr
;
...
@@ -104,13 +113,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -104,13 +113,19 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
||
(
gbdt_config_
!=
nullptr
&&
gbdt_config_
->
bagging_fraction
!=
new_config
->
bagging_fraction
))
{
||
(
gbdt_config_
!=
nullptr
&&
gbdt_config_
->
bagging_fraction
!=
new_config
->
bagging_fraction
))
{
// if need bagging, create buffer
// if need bagging, create buffer
if
(
new_config
->
bagging_fraction
<
1.0
&&
new_config
->
bagging_freq
>
0
)
{
if
(
new_config
->
bagging_fraction
<
1.0
&&
new_config
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
.
resize
(
num_data_
);
bag_data_cnt_
=
static_cast
<
data_size_t
>
(
new_config
->
bagging_fraction
*
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
tmp_indices_
.
resize
(
num_data_
);
offsets_buf_
.
resize
(
num_threads_
);
left_cnts_buf_
.
resize
(
num_threads_
);
right_cnts_buf_
.
resize
(
num_threads_
);
left_write_pos_buf_
.
resize
(
num_threads_
);
right_write_pos_buf_
.
resize
(
num_threads_
);
}
else
{
}
else
{
out_of_bag_data_cnt_
=
0
;
out_of_bag_data_indices_
.
clear
();
bag_data_cnt_
=
num_data_
;
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
bag_data_indices_
.
clear
();
tmp_indices_
.
clear
();
}
}
}
}
train_data_
=
train_data
;
train_data_
=
train_data
;
...
@@ -153,53 +168,65 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
...
@@ -153,53 +168,65 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
valid_metrics_
.
back
().
shrink_to_fit
();
valid_metrics_
.
back
().
shrink_to_fit
();
}
}
// Per-thread bagging helper: samples rows without replacement from the
// half-open chunk [start, start + cnt).
// In-bag row indices are written to buffer[0 .. left) and out-of-bag indices
// to buffer[bag_data_cnt .. cnt); returns the in-bag count.
// NOTE(review): assumes `buffer` has room for at least `cnt` entries and that
// concurrent callers work on disjoint chunks — confirm at the call site.
data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer){
  // Thread id selects the per-thread RNG (random_[tid]) so parallel calls
  // do not share generator state.
  const int tid = omp_get_thread_num();
  // Number of rows to keep in-bag from this chunk.
  data_size_t bag_data_cnt = static_cast<data_size_t>(gbdt_config_->bagging_fraction * cnt);
  data_size_t cur_left_cnt = 0;
  data_size_t cur_right_cnt = 0;
  // random bagging, minimal unit is one record
  for (data_size_t i = 0; i < cnt; ++i) {
    // Selection probability = (slots still needed) / (records remaining).
    // This selection-sampling scheme picks exactly bag_data_cnt records:
    // prob becomes 1.0 when every remaining record is needed, and 0.0 once
    // the quota is filled.
    double prob = (bag_data_cnt - cur_left_cnt) / static_cast<double>(cnt - i);
    if (random_[tid].NextDouble() < prob) {
      // selected: goes to the in-bag (left) region of the buffer
      buffer[cur_left_cnt++] = start + i;
    } else {
      // not selected: goes to the out-of-bag (right) region, which starts
      // at offset bag_data_cnt
      buffer[bag_data_cnt + cur_right_cnt++] = start + i;
    }
  }
  // The sampling scheme above guarantees exactly bag_data_cnt selections.
  CHECK(cur_left_cnt == bag_data_cnt);
  return cur_left_cnt;
}
void
GBDT
::
Bagging
(
int
iter
)
{
void
GBDT
::
Bagging
(
int
iter
)
{
// if need bagging
// if need bagging
if
(
!
out_of_bag_data_indices_
.
empty
()
&&
iter
%
gbdt_config_
->
bagging_freq
==
0
)
{
if
(
bag_data_cnt_
<
num_data_
&&
iter
%
gbdt_config_
->
bagging_freq
==
0
)
{
// if doesn't have query data
const
data_size_t
min_inner_size
=
10000
;
if
(
train_data_
->
metadata
().
query_boundaries
()
==
nullptr
)
{
data_size_t
inner_size
=
(
num_data_
+
num_threads_
-
1
)
/
num_threads_
;
bag_data_cnt_
=
if
(
inner_size
<
min_inner_size
)
{
inner_size
=
min_inner_size
;
}
static_cast
<
data_size_t
>
(
gbdt_config_
->
bagging_fraction
*
num_data_
);
out_of_bag_data_cnt_
=
num_data_
-
bag_data_cnt_
;
#pragma omp parallel for schedule(static, 1)
data_size_t
cur_left_cnt
=
0
;
for
(
int
i
=
0
;
i
<
num_threads_
;
++
i
)
{
data_size_t
cur_right_cnt
=
0
;
left_cnts_buf_
[
i
]
=
0
;
// random bagging, minimal unit is one record
right_cnts_buf_
[
i
]
=
0
;
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
data_size_t
cur_start
=
i
*
inner_size
;
double
prob
=
if
(
cur_start
>
num_data_
)
{
continue
;
}
(
bag_data_cnt_
-
cur_left_cnt
)
/
static_cast
<
double
>
(
num_data_
-
i
);
data_size_t
cur_cnt
=
inner_size
;
if
(
random_
.
NextDouble
()
<
prob
)
{
if
(
cur_start
+
cur_cnt
>
num_data_
)
{
cur_cnt
=
num_data_
-
cur_start
;
}
bag_data_indices_
[
cur_left_cnt
++
]
=
i
;
data_size_t
cur_left_count
=
BaggingHelper
(
cur_start
,
cur_cnt
,
tmp_indices_
.
data
()
+
cur_start
);
}
else
{
offsets_buf_
[
i
]
=
cur_start
;
out_of_bag_data_indices_
[
cur_right_cnt
++
]
=
i
;
left_cnts_buf_
[
i
]
=
cur_left_count
;
}
right_cnts_buf_
[
i
]
=
cur_cnt
-
cur_left_count
;
}
data_size_t
left_cnt
=
0
;
left_write_pos_buf_
[
0
]
=
0
;
right_write_pos_buf_
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<
num_threads_
;
++
i
)
{
left_write_pos_buf_
[
i
]
=
left_write_pos_buf_
[
i
-
1
]
+
left_cnts_buf_
[
i
-
1
];
right_write_pos_buf_
[
i
]
=
right_write_pos_buf_
[
i
-
1
]
+
right_cnts_buf_
[
i
-
1
];
}
left_cnt
=
left_write_pos_buf_
[
num_threads_
-
1
]
+
left_cnts_buf_
[
num_threads_
-
1
];
#pragma omp parallel for schedule(static, 1)
for
(
int
i
=
0
;
i
<
num_threads_
;
++
i
)
{
if
(
left_cnts_buf_
[
i
]
>
0
)
{
std
::
memcpy
(
bag_data_indices_
.
data
()
+
left_write_pos_buf_
[
i
],
tmp_indices_
.
data
()
+
offsets_buf_
[
i
],
left_cnts_buf_
[
i
]
*
sizeof
(
data_size_t
));
}
}
}
else
{
if
(
right_cnts_buf_
[
i
]
>
0
)
{
// if have query data
std
::
memcpy
(
bag_data_indices_
.
data
()
+
left_cnt
+
right_write_pos_buf_
[
i
],
const
data_size_t
*
query_boundaries
=
train_data_
->
metadata
().
query_boundaries
();
tmp_indices_
.
data
()
+
offsets_buf_
[
i
]
+
left_cnts_buf_
[
i
],
right_cnts_buf_
[
i
]
*
sizeof
(
data_size_t
));
data_size_t
num_query
=
train_data_
->
metadata
().
num_queries
();
data_size_t
bag_query_cnt
=
static_cast
<
data_size_t
>
(
num_query
*
gbdt_config_
->
bagging_fraction
);
data_size_t
cur_left_query_cnt
=
0
;
data_size_t
cur_left_cnt
=
0
;
data_size_t
cur_right_cnt
=
0
;
// random bagging, minimal unit is one query
for
(
data_size_t
i
=
0
;
i
<
num_query
;
++
i
)
{
double
prob
=
(
bag_query_cnt
-
cur_left_query_cnt
)
/
static_cast
<
double
>
(
num_query
-
i
);
if
(
random_
.
NextDouble
()
<
prob
)
{
for
(
data_size_t
j
=
query_boundaries
[
i
];
j
<
query_boundaries
[
i
+
1
];
++
j
)
{
bag_data_indices_
[
cur_left_cnt
++
]
=
j
;
}
cur_left_query_cnt
++
;
}
else
{
for
(
data_size_t
j
=
query_boundaries
[
i
];
j
<
query_boundaries
[
i
+
1
];
++
j
)
{
out_of_bag_data_indices_
[
cur_right_cnt
++
]
=
j
;
}
}
}
}
bag_data_cnt_
=
cur_left_cnt
;
out_of_bag_data_cnt_
=
num_data_
-
bag_data_cnt_
;
}
}
Log
::
Debug
(
"Re-bagging, using %d data to train"
,
bag_data_cnt_
);
Log
::
Debug
(
"Re-bagging, using %d data to train"
,
bag_data_cnt_
);
// set bagging data to tree learner
// set bagging data to tree learner
...
@@ -209,8 +236,8 @@ void GBDT::Bagging(int iter) {
...
@@ -209,8 +236,8 @@ void GBDT::Bagging(int iter) {
// Adds the new tree's output to the cached training scores of the rows that
// were left out of the current bag, so every row's score stays current for
// the next boosting iteration.
// \param tree newly trained tree of this iteration
// \param curr_class class (for multiclass training) whose score is updated
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
  // we need to predict out-of-bag scores of data for boosting
  if (num_data_ - bag_data_cnt_ > 0) {
    // Out-of-bag indices live in the tail of bag_data_indices_ starting at
    // offset bag_data_cnt_ (written there by the bagging step), so a single
    // AddScore call covers all num_data_ - bag_data_cnt_ of them.
    train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
  }
}
...
...
src/boosting/gbdt.h
View file @
5d12a8db
...
@@ -219,7 +219,16 @@ protected:
...
@@ -219,7 +219,16 @@ protected:
* \brief Implement bagging logic
* \brief Implement bagging logic
* \param iter Current iteration
* \param iter Current iteration
*/
*/
void
Bagging
(
int
iter
);
virtual
void
Bagging
(
int
iter
);
/*!
* \brief Helper function for bagging, used for multi-threading optimization
* \param start start index of bagging
* \param cnt count
* \param buffer output buffer
* \return count of left size
*/
virtual
data_size_t
BaggingHelper
(
data_size_t
start
,
data_size_t
cnt
,
data_size_t
*
buffer
);
/*!
/*!
* \brief updating score for out-of-bag data.
* \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training
* Data should be update since we may re-bagging data on training
...
@@ -282,20 +291,18 @@ protected:
...
@@ -282,20 +291,18 @@ protected:
std
::
vector
<
score_t
>
gradients_
;
std
::
vector
<
score_t
>
gradients_
;
/*! \brief Second order derivative of training data */
/*! \brief Second order derivative of training data */
std
::
vector
<
score_t
>
hessians_
;
std
::
vector
<
score_t
>
hessians_
;
/*! \brief Store the data indices of out-of-bag */
std
::
vector
<
data_size_t
>
out_of_bag_data_indices_
;
/*! \brief Number of out-of-bag data */
data_size_t
out_of_bag_data_cnt_
;
/*! \brief Store the indices of in-bag data */
/*! \brief Store the indices of in-bag data */
std
::
vector
<
data_size_t
>
bag_data_indices_
;
std
::
vector
<
data_size_t
>
bag_data_indices_
;
/*! \brief Number of in-bag data */
/*! \brief Number of in-bag data */
data_size_t
bag_data_cnt_
;
data_size_t
bag_data_cnt_
;
/*! \brief Store the indices of in-bag data */
std
::
vector
<
data_size_t
>
tmp_indices_
;
/*! \brief Number of training data */
/*! \brief Number of training data */
data_size_t
num_data_
;
data_size_t
num_data_
;
/*! \brief Number of classes */
/*! \brief Number of classes */
int
num_class_
;
int
num_class_
;
/*! \brief Random generator, used for bagging */
/*! \brief Random generator, used for bagging */
Random
random_
;
std
::
vector
<
Random
>
random_
;
/*!
/*!
* \brief Sigmoid parameter, used for prediction.
* \brief Sigmoid parameter, used for prediction.
* if > 0 means output score will transform by sigmoid function
* if > 0 means output score will transform by sigmoid function
...
@@ -311,6 +318,18 @@ protected:
...
@@ -311,6 +318,18 @@ protected:
int
num_init_iteration_
;
int
num_init_iteration_
;
/*! \brief Feature names */
/*! \brief Feature names */
std
::
vector
<
std
::
string
>
feature_names_
;
std
::
vector
<
std
::
string
>
feature_names_
;
/*! \brief number of threads */
int
num_threads_
;
/*! \brief Buffer for multi-threading bagging */
std
::
vector
<
data_size_t
>
offsets_buf_
;
/*! \brief Buffer for multi-threading bagging */
std
::
vector
<
data_size_t
>
left_cnts_buf_
;
/*! \brief Buffer for multi-threading bagging */
std
::
vector
<
data_size_t
>
right_cnts_buf_
;
/*! \brief Buffer for multi-threading bagging */
std
::
vector
<
data_size_t
>
left_write_pos_buf_
;
/*! \brief Buffer for multi-threading bagging */
std
::
vector
<
data_size_t
>
right_write_pos_buf_
;
};
};
}
// namespace LightGBM
}
// namespace LightGBM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment