Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
da91c613
Unverified
Commit
da91c613
authored
Mar 02, 2020
by
Guolin Ke
Committed by
GitHub
Mar 02, 2020
Browse files
fix bug in parallel learning (#2851)
* refix * fix config * avoid to rely on config
parent
9c386db1
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
95 additions
and
68 deletions
+95
-68
include/LightGBM/config.h
include/LightGBM/config.h
+1
-1
src/application/application.cpp
src/application/application.cpp
+2
-2
src/io/config.cpp
src/io/config.cpp
+2
-2
src/treelearner/data_parallel_tree_learner.cpp
src/treelearner/data_parallel_tree_learner.cpp
+1
-1
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+82
-60
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+6
-1
src/treelearner/voting_parallel_tree_learner.cpp
src/treelearner/voting_parallel_tree_learner.cpp
+1
-1
No files found.
include/LightGBM/config.h
View file @
da91c613
...
@@ -908,7 +908,7 @@ struct Config {
...
@@ -908,7 +908,7 @@ struct Config {
size_t
file_load_progress_interval_bytes
=
size_t
(
10
)
*
1024
*
1024
*
1024
;
size_t
file_load_progress_interval_bytes
=
size_t
(
10
)
*
1024
*
1024
*
1024
;
bool
is_parallel
=
false
;
bool
is_parallel
=
false
;
bool
is_
parallel_find_bin
=
false
;
bool
is_
data_based_parallel
=
false
;
LIGHTGBM_EXPORT
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
LIGHTGBM_EXPORT
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
static
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
alias_table
();
static
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
alias_table
();
static
const
std
::
unordered_set
<
std
::
string
>&
parameter_set
();
static
const
std
::
unordered_set
<
std
::
string
>&
parameter_set
();
...
...
src/application/application.cpp
View file @
da91c613
...
@@ -93,7 +93,7 @@ void Application::LoadData() {
...
@@ -93,7 +93,7 @@ void Application::LoadData() {
}
}
// sync up random seed for data partition
// sync up random seed for data partition
if
(
config_
.
is_
parallel_find_bin
)
{
if
(
config_
.
is_
data_based_parallel
)
{
config_
.
data_random_seed
=
Network
::
GlobalSyncUpByMin
(
config_
.
data_random_seed
);
config_
.
data_random_seed
=
Network
::
GlobalSyncUpByMin
(
config_
.
data_random_seed
);
}
}
...
@@ -101,7 +101,7 @@ void Application::LoadData() {
...
@@ -101,7 +101,7 @@ void Application::LoadData() {
DatasetLoader
dataset_loader
(
config_
,
predict_fun
,
DatasetLoader
dataset_loader
(
config_
,
predict_fun
,
config_
.
num_class
,
config_
.
data
.
c_str
());
config_
.
num_class
,
config_
.
data
.
c_str
());
// load Training data
// load Training data
if
(
config_
.
is_
parallel_find_bin
)
{
if
(
config_
.
is_
data_based_parallel
)
{
// load data for parallel training
// load data for parallel training
train_data_
.
reset
(
dataset_loader
.
LoadFromFile
(
config_
.
data
.
c_str
(),
train_data_
.
reset
(
dataset_loader
.
LoadFromFile
(
config_
.
data
.
c_str
(),
Network
::
rank
(),
Network
::
num_machines
()));
Network
::
rank
(),
Network
::
num_machines
()));
...
...
src/io/config.cpp
View file @
da91c613
...
@@ -280,10 +280,10 @@ void Config::CheckParamConflict() {
...
@@ -280,10 +280,10 @@ void Config::CheckParamConflict() {
}
}
if
(
is_single_tree_learner
||
tree_learner
==
std
::
string
(
"feature"
))
{
if
(
is_single_tree_learner
||
tree_learner
==
std
::
string
(
"feature"
))
{
is_
parallel_find_bin
=
false
;
is_
data_based_parallel
=
false
;
}
else
if
(
tree_learner
==
std
::
string
(
"data"
)
}
else
if
(
tree_learner
==
std
::
string
(
"data"
)
||
tree_learner
==
std
::
string
(
"voting"
))
{
||
tree_learner
==
std
::
string
(
"voting"
))
{
is_
parallel_find_bin
=
true
;
is_
data_based_parallel
=
true
;
if
(
histogram_pool_size
>=
0
if
(
histogram_pool_size
>=
0
&&
tree_learner
==
std
::
string
(
"data"
))
{
&&
tree_learner
==
std
::
string
(
"data"
))
{
Log
::
Warning
(
"Histogram LRU queue was enabled (histogram_pool_size=%f).
\n
"
Log
::
Warning
(
"Histogram LRU queue was enabled (histogram_pool_size=%f).
\n
"
...
...
src/treelearner/data_parallel_tree_learner.cpp
View file @
da91c613
...
@@ -241,7 +241,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
...
@@ -241,7 +241,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
template
<
typename
TREELEARNER_T
>
template
<
typename
TREELEARNER_T
>
void
DataParallelTreeLearner
<
TREELEARNER_T
>::
Split
(
Tree
*
tree
,
int
best_Leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
void
DataParallelTreeLearner
<
TREELEARNER_T
>::
Split
(
Tree
*
tree
,
int
best_Leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
TREELEARNER_T
::
Split
(
tree
,
best_Leaf
,
left_leaf
,
right_leaf
);
this
->
SplitInner
(
tree
,
best_Leaf
,
left_leaf
,
right_leaf
,
false
);
const
SplitInfo
&
best_split_info
=
this
->
best_split_per_leaf_
[
best_Leaf
];
const
SplitInfo
&
best_split_info
=
this
->
best_split_per_leaf_
[
best_Leaf
];
// need update global number of data in leaf
// need update global number of data in leaf
global_data_count_in_leaf_
[
*
left_leaf
]
=
best_split_info
.
left_count
;
global_data_count_in_leaf_
[
*
left_leaf
]
=
best_split_info
.
left_count
;
...
...
src/treelearner/serial_tree_learner.cpp
View file @
da91c613
...
@@ -648,29 +648,37 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
...
@@ -648,29 +648,37 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
return
result_count
;
return
result_count
;
}
}
void
SerialTreeLearner
::
Split
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
void
SerialTreeLearner
::
SplitInner
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
Common
::
FunctionTimer
fun_timer
(
"SerialTreeLearner::Split"
,
global_timer
);
int
*
right_leaf
,
bool
update_cnt
)
{
Common
::
FunctionTimer
fun_timer
(
"SerialTreeLearner::SplitInner"
,
global_timer
);
SplitInfo
&
best_split_info
=
best_split_per_leaf_
[
best_leaf
];
SplitInfo
&
best_split_info
=
best_split_per_leaf_
[
best_leaf
];
const
int
inner_feature_index
=
train_data_
->
InnerFeatureIndex
(
best_split_info
.
feature
);
const
int
inner_feature_index
=
train_data_
->
InnerFeatureIndex
(
best_split_info
.
feature
);
if
(
cegb_
!=
nullptr
)
{
if
(
cegb_
!=
nullptr
)
{
cegb_
->
UpdateLeafBestSplits
(
tree
,
best_leaf
,
&
best_split_info
,
&
best_split_per_leaf_
);
cegb_
->
UpdateLeafBestSplits
(
tree
,
best_leaf
,
&
best_split_info
,
&
best_split_per_leaf_
);
}
}
*
left_leaf
=
best_leaf
;
*
left_leaf
=
best_leaf
;
auto
next_leaf_id
=
tree
->
NextLeafId
();
auto
next_leaf_id
=
tree
->
NextLeafId
();
bool
is_numerical_split
=
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
bin_type
()
==
BinType
::
NumericalBin
;
bool
is_numerical_split
=
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
bin_type
()
==
BinType
::
NumericalBin
;
if
(
is_numerical_split
)
{
if
(
is_numerical_split
)
{
auto
threshold_double
=
train_data_
->
RealThreshold
(
inner_feature_index
,
best_split_info
.
threshold
);
auto
threshold_double
=
train_data_
->
RealThreshold
(
inner_feature_index
,
best_split_info
.
threshold
);
data_partition_
->
Split
(
best_leaf
,
train_data_
,
inner_feature_index
,
data_partition_
->
Split
(
best_leaf
,
train_data_
,
inner_feature_index
,
&
best_split_info
.
threshold
,
1
,
best_split_info
.
default_left
,
next_leaf_id
);
&
best_split_info
.
threshold
,
1
,
best_split_info
.
default_left
,
next_leaf_id
);
if
(
update_cnt
)
{
// don't need to update this in data-based parallel model
best_split_info
.
left_count
=
data_partition_
->
leaf_count
(
*
left_leaf
);
best_split_info
.
left_count
=
data_partition_
->
leaf_count
(
*
left_leaf
);
best_split_info
.
right_count
=
data_partition_
->
leaf_count
(
next_leaf_id
);
best_split_info
.
right_count
=
data_partition_
->
leaf_count
(
next_leaf_id
);
}
// split tree, will return right leaf
// split tree, will return right leaf
*
right_leaf
=
tree
->
Split
(
best_leaf
,
*
right_leaf
=
tree
->
Split
(
inner_feature_index
,
best_leaf
,
inner_feature_index
,
best_split_info
.
feature
,
best_split_info
.
feature
,
best_split_info
.
threshold
,
threshold_double
,
best_split_info
.
threshold
,
threshold_double
,
static_cast
<
double
>
(
best_split_info
.
left_output
),
static_cast
<
double
>
(
best_split_info
.
left_output
),
static_cast
<
double
>
(
best_split_info
.
right_output
),
static_cast
<
double
>
(
best_split_info
.
right_output
),
static_cast
<
data_size_t
>
(
best_split_info
.
left_count
),
static_cast
<
data_size_t
>
(
best_split_info
.
left_count
),
...
@@ -681,26 +689,32 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
...
@@ -681,26 +689,32 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
(),
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
(),
best_split_info
.
default_left
);
best_split_info
.
default_left
);
}
else
{
}
else
{
std
::
vector
<
uint32_t
>
cat_bitset_inner
=
Common
::
ConstructBitset
(
best_split_info
.
cat_threshold
.
data
(),
best_split_info
.
num_cat_threshold
);
std
::
vector
<
uint32_t
>
cat_bitset_inner
=
Common
::
ConstructBitset
(
best_split_info
.
cat_threshold
.
data
(),
best_split_info
.
num_cat_threshold
);
std
::
vector
<
int
>
threshold_int
(
best_split_info
.
num_cat_threshold
);
std
::
vector
<
int
>
threshold_int
(
best_split_info
.
num_cat_threshold
);
for
(
int
i
=
0
;
i
<
best_split_info
.
num_cat_threshold
;
++
i
)
{
for
(
int
i
=
0
;
i
<
best_split_info
.
num_cat_threshold
;
++
i
)
{
threshold_int
[
i
]
=
static_cast
<
int
>
(
train_data_
->
RealThreshold
(
inner_feature_index
,
best_split_info
.
cat_threshold
[
i
]));
threshold_int
[
i
]
=
static_cast
<
int
>
(
train_data_
->
RealThreshold
(
inner_feature_index
,
best_split_info
.
cat_threshold
[
i
]));
}
}
std
::
vector
<
uint32_t
>
cat_bitset
=
Common
::
ConstructBitset
(
threshold_int
.
data
(),
best_split_info
.
num_cat_threshold
);
std
::
vector
<
uint32_t
>
cat_bitset
=
Common
::
ConstructBitset
(
threshold_int
.
data
(),
best_split_info
.
num_cat_threshold
);
data_partition_
->
Split
(
best_leaf
,
train_data_
,
inner_feature_index
,
data_partition_
->
Split
(
best_leaf
,
train_data_
,
inner_feature_index
,
cat_bitset_inner
.
data
(),
static_cast
<
int
>
(
cat_bitset_inner
.
size
()),
best_split_info
.
default_left
,
next_leaf_id
);
cat_bitset_inner
.
data
(),
static_cast
<
int
>
(
cat_bitset_inner
.
size
()),
best_split_info
.
default_left
,
next_leaf_id
);
if
(
update_cnt
)
{
// don't need to update this in data-based parallel model
best_split_info
.
left_count
=
data_partition_
->
leaf_count
(
*
left_leaf
);
best_split_info
.
left_count
=
data_partition_
->
leaf_count
(
*
left_leaf
);
best_split_info
.
right_count
=
data_partition_
->
leaf_count
(
next_leaf_id
);
best_split_info
.
right_count
=
data_partition_
->
leaf_count
(
next_leaf_id
);
}
*
right_leaf
=
tree
->
SplitCategorical
(
best_leaf
,
*
right_leaf
=
tree
->
SplitCategorical
(
inner_feature_index
,
best_leaf
,
inner_feature_index
,
best_split_info
.
feature
,
best_split_info
.
feature
,
cat_bitset_inner
.
data
(),
static_cast
<
int
>
(
cat_bitset_inner
.
size
()),
cat_bitset_inner
.
data
(),
cat_bitset
.
data
(),
static_cast
<
int
>
(
cat_bitset
.
size
()),
static_cast
<
int
>
(
cat_bitset_inner
.
size
()),
cat_bitset
.
data
(),
static_cast
<
int
>
(
cat_bitset
.
size
()),
static_cast
<
double
>
(
best_split_info
.
left_output
),
static_cast
<
double
>
(
best_split_info
.
left_output
),
static_cast
<
double
>
(
best_split_info
.
right_output
),
static_cast
<
double
>
(
best_split_info
.
right_output
),
static_cast
<
data_size_t
>
(
best_split_info
.
left_count
),
static_cast
<
data_size_t
>
(
best_split_info
.
left_count
),
...
@@ -711,26 +725,34 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
...
@@ -711,26 +725,34 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
());
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
());
}
}
#ifdef DEBUG
#ifdef DEBUG
CHECK
(
*
right_leaf
==
next_leaf_id
);
CHECK
(
*
right_leaf
==
next_leaf_id
);
#endif
#endif
// init the leaves that used on next iteration
// init the leaves that used on next iteration
if
(
best_split_info
.
left_count
<
best_split_info
.
right_count
)
{
if
(
best_split_info
.
left_count
<
best_split_info
.
right_count
)
{
CHECK_GT
(
best_split_info
.
left_count
,
0
);
CHECK_GT
(
best_split_info
.
left_count
,
0
);
smaller_leaf_splits_
->
Init
(
*
left_leaf
,
data_partition_
.
get
(),
best_split_info
.
left_sum_gradient
,
best_split_info
.
left_sum_hessian
);
smaller_leaf_splits_
->
Init
(
*
left_leaf
,
data_partition_
.
get
(),
larger_leaf_splits_
->
Init
(
*
right_leaf
,
data_partition_
.
get
(),
best_split_info
.
right_sum_gradient
,
best_split_info
.
right_sum_hessian
);
best_split_info
.
left_sum_gradient
,
best_split_info
.
left_sum_hessian
);
larger_leaf_splits_
->
Init
(
*
right_leaf
,
data_partition_
.
get
(),
best_split_info
.
right_sum_gradient
,
best_split_info
.
right_sum_hessian
);
}
else
{
}
else
{
CHECK_GT
(
best_split_info
.
right_count
,
0
);
CHECK_GT
(
best_split_info
.
right_count
,
0
);
smaller_leaf_splits_
->
Init
(
*
right_leaf
,
data_partition_
.
get
(),
best_split_info
.
right_sum_gradient
,
best_split_info
.
right_sum_hessian
);
smaller_leaf_splits_
->
Init
(
*
right_leaf
,
data_partition_
.
get
(),
larger_leaf_splits_
->
Init
(
*
left_leaf
,
data_partition_
.
get
(),
best_split_info
.
left_sum_gradient
,
best_split_info
.
left_sum_hessian
);
best_split_info
.
right_sum_gradient
,
best_split_info
.
right_sum_hessian
);
larger_leaf_splits_
->
Init
(
*
left_leaf
,
data_partition_
.
get
(),
best_split_info
.
left_sum_gradient
,
best_split_info
.
left_sum_hessian
);
}
}
constraints_
->
UpdateConstraints
(
constraints_
->
UpdateConstraints
(
is_numerical_split
,
*
left_leaf
,
*
right_leaf
,
is_numerical_split
,
*
left_leaf
,
*
right_leaf
,
best_split_info
.
monotone_type
,
best_split_info
.
monotone_type
,
best_split_info
.
right_output
,
best_split_info
.
right_output
,
best_split_info
.
left_output
);
best_split_info
.
left_output
);
}
}
void
SerialTreeLearner
::
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
void
SerialTreeLearner
::
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
{
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
{
...
...
src/treelearner/serial_tree_learner.h
View file @
da91c613
...
@@ -137,7 +137,12 @@ class SerialTreeLearner: public TreeLearner {
...
@@ -137,7 +137,12 @@ class SerialTreeLearner: public TreeLearner {
* \param left_leaf The index of left leaf after splitted.
* \param left_leaf The index of left leaf after splitted.
* \param right_leaf The index of right leaf after splitted.
* \param right_leaf The index of right leaf after splitted.
*/
*/
virtual
void
Split
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
int
*
right_leaf
);
inline
virtual
void
Split
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
SplitInner
(
tree
,
best_leaf
,
left_leaf
,
right_leaf
,
true
);
}
void
SplitInner
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
int
*
right_leaf
,
bool
update_cnt
);
/* Force splits with forced_split_json dict and then return num splits forced.*/
/* Force splits with forced_split_json dict and then return num splits forced.*/
virtual
int32_t
ForceSplits
(
Tree
*
tree
,
const
Json
&
forced_split_json
,
int
*
left_leaf
,
virtual
int32_t
ForceSplits
(
Tree
*
tree
,
const
Json
&
forced_split_json
,
int
*
left_leaf
,
...
...
src/treelearner/voting_parallel_tree_learner.cpp
View file @
da91c613
...
@@ -429,7 +429,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
...
@@ -429,7 +429,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
template
<
typename
TREELEARNER_T
>
template
<
typename
TREELEARNER_T
>
void
VotingParallelTreeLearner
<
TREELEARNER_T
>::
Split
(
Tree
*
tree
,
int
best_Leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
void
VotingParallelTreeLearner
<
TREELEARNER_T
>::
Split
(
Tree
*
tree
,
int
best_Leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
TREELEARNER_T
::
Split
(
tree
,
best_Leaf
,
left_leaf
,
right_leaf
);
this
->
SplitInner
(
tree
,
best_Leaf
,
left_leaf
,
right_leaf
,
false
);
const
SplitInfo
&
best_split_info
=
this
->
best_split_per_leaf_
[
best_Leaf
];
const
SplitInfo
&
best_split_info
=
this
->
best_split_per_leaf_
[
best_Leaf
];
// set the global number of data for leaves
// set the global number of data for leaves
global_data_count_in_leaf_
[
*
left_leaf
]
=
best_split_info
.
left_count
;
global_data_count_in_leaf_
[
*
left_leaf
]
=
best_split_info
.
left_count
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment