Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
49ea824f
Unverified
Commit
49ea824f
authored
Feb 29, 2020
by
Guolin Ke
Committed by
GitHub
Feb 29, 2020
Browse files
fix forced split (#2838)
parent
53137e25
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
7 deletions
+14
-7
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+14
-7
No files found.
src/treelearner/serial_tree_learner.cpp
View file @
49ea824f
...
@@ -546,7 +546,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
...
@@ -546,7 +546,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
// split tree, will return right leaf
// split tree, will return right leaf
*
left_leaf
=
current_leaf
;
*
left_leaf
=
current_leaf
;
auto
next_leaf_id
=
tree
->
NextLeafId
();
if
(
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
bin_type
()
==
BinType
::
NumericalBin
)
{
if
(
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
bin_type
()
==
BinType
::
NumericalBin
)
{
data_partition_
->
Split
(
current_leaf
,
train_data_
,
inner_feature_index
,
&
current_split_info
.
threshold
,
1
,
current_split_info
.
default_left
,
next_leaf_id
);
current_split_info
.
left_count
=
data_partition_
->
leaf_count
(
*
left_leaf
);
current_split_info
.
right_count
=
data_partition_
->
leaf_count
(
next_leaf_id
);
*
right_leaf
=
tree
->
Split
(
current_leaf
,
*
right_leaf
=
tree
->
Split
(
current_leaf
,
inner_feature_index
,
inner_feature_index
,
current_split_info
.
feature
,
current_split_info
.
feature
,
...
@@ -561,9 +567,6 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
...
@@ -561,9 +567,6 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
static_cast
<
float
>
(
current_split_info
.
gain
),
static_cast
<
float
>
(
current_split_info
.
gain
),
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
(),
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
(),
current_split_info
.
default_left
);
current_split_info
.
default_left
);
data_partition_
->
Split
(
current_leaf
,
train_data_
,
inner_feature_index
,
&
current_split_info
.
threshold
,
1
,
current_split_info
.
default_left
,
*
right_leaf
);
}
else
{
}
else
{
std
::
vector
<
uint32_t
>
cat_bitset_inner
=
Common
::
ConstructBitset
(
std
::
vector
<
uint32_t
>
cat_bitset_inner
=
Common
::
ConstructBitset
(
current_split_info
.
cat_threshold
.
data
(),
current_split_info
.
num_cat_threshold
);
current_split_info
.
cat_threshold
.
data
(),
current_split_info
.
num_cat_threshold
);
...
@@ -574,6 +577,11 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
...
@@ -574,6 +577,11 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
}
}
std
::
vector
<
uint32_t
>
cat_bitset
=
Common
::
ConstructBitset
(
std
::
vector
<
uint32_t
>
cat_bitset
=
Common
::
ConstructBitset
(
threshold_int
.
data
(),
current_split_info
.
num_cat_threshold
);
threshold_int
.
data
(),
current_split_info
.
num_cat_threshold
);
data_partition_
->
Split
(
current_leaf
,
train_data_
,
inner_feature_index
,
cat_bitset_inner
.
data
(),
static_cast
<
int
>
(
cat_bitset_inner
.
size
()),
current_split_info
.
default_left
,
next_leaf_id
);
current_split_info
.
left_count
=
data_partition_
->
leaf_count
(
*
left_leaf
);
current_split_info
.
right_count
=
data_partition_
->
leaf_count
(
next_leaf_id
);
*
right_leaf
=
tree
->
SplitCategorical
(
current_leaf
,
*
right_leaf
=
tree
->
SplitCategorical
(
current_leaf
,
inner_feature_index
,
inner_feature_index
,
current_split_info
.
feature
,
current_split_info
.
feature
,
...
@@ -589,11 +597,10 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
...
@@ -589,11 +597,10 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
static_cast
<
double
>
(
current_split_info
.
right_sum_hessian
),
static_cast
<
double
>
(
current_split_info
.
right_sum_hessian
),
static_cast
<
float
>
(
current_split_info
.
gain
),
static_cast
<
float
>
(
current_split_info
.
gain
),
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
());
train_data_
->
FeatureBinMapper
(
inner_feature_index
)
->
missing_type
());
data_partition_
->
Split
(
current_leaf
,
train_data_
,
inner_feature_index
,
cat_bitset_inner
.
data
(),
static_cast
<
int
>
(
cat_bitset_inner
.
size
()),
current_split_info
.
default_left
,
*
right_leaf
);
}
}
#ifdef DEBUG
CHECK
(
*
right_leaf
==
next_leaf_id
);
#endif
if
(
current_split_info
.
left_count
<
current_split_info
.
right_count
)
{
if
(
current_split_info
.
left_count
<
current_split_info
.
right_count
)
{
left_smaller
=
true
;
left_smaller
=
true
;
smaller_leaf_splits_
->
Init
(
*
left_leaf
,
data_partition_
.
get
(),
smaller_leaf_splits_
->
Init
(
*
left_leaf
,
data_partition_
.
get
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment