Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
f84bfcf9
Unverified
Commit
f84bfcf9
authored
Dec 30, 2022
by
Belinda Trotta
Committed by
GitHub
Dec 30, 2022
Browse files
Check feature indexes in forced split file (fixes #5517) (#5653)
parent
51edbda7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
0 deletions
+48
-0
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+24
-0
src/boosting/gbdt.h
src/boosting/gbdt.h
+5
-0
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+19
-0
No files found.
src/boosting/gbdt.cpp
View file @
f84bfcf9
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include <chrono>
#include <chrono>
#include <ctime>
#include <ctime>
#include <queue>
#include <sstream>
#include <sstream>
namespace
LightGBM
{
namespace
LightGBM
{
...
@@ -138,6 +139,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
...
@@ -138,6 +139,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
// get parser config file content
// get parser config file content
parser_config_str_
=
train_data_
->
parser_config_str
();
parser_config_str_
=
train_data_
->
parser_config_str
();
// check that forced splits does not use feature indices larger than dataset size
CheckForcedSplitFeatures
();
// if need bagging, create buffer
// if need bagging, create buffer
data_sample_strategy_
->
ResetSampleConfig
(
config_
.
get
(),
true
);
data_sample_strategy_
->
ResetSampleConfig
(
config_
.
get
(),
true
);
ResetGradientBuffers
();
ResetGradientBuffers
();
...
@@ -155,6 +159,26 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
...
@@ -155,6 +159,26 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
}
}
}
}
void
GBDT
::
CheckForcedSplitFeatures
()
{
std
::
queue
<
Json
>
forced_split_nodes
;
forced_split_nodes
.
push
(
forced_splits_json_
);
while
(
!
forced_split_nodes
.
empty
())
{
Json
node
=
forced_split_nodes
.
front
();
forced_split_nodes
.
pop
();
const
int
feature_index
=
node
[
"feature"
].
int_value
();
if
(
feature_index
>
max_feature_idx_
)
{
Log
::
Fatal
(
"Forced splits file includes feature index %d, but maximum feature index in dataset is %d"
,
feature_index
,
max_feature_idx_
);
}
if
(
node
.
object_items
().
count
(
"left"
)
>
0
)
{
forced_split_nodes
.
push
(
node
[
"left"
]);
}
if
(
node
.
object_items
().
count
(
"right"
)
>
0
)
{
forced_split_nodes
.
push
(
node
[
"right"
]);
}
}
}
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
if
(
!
train_data_
->
CheckAlign
(
*
valid_data
))
{
if
(
!
train_data_
->
CheckAlign
(
*
valid_data
))
{
...
...
src/boosting/gbdt.h
View file @
f84bfcf9
...
@@ -58,6 +58,11 @@ class GBDT : public GBDTBase {
...
@@ -58,6 +58,11 @@ class GBDT : public GBDTBase {
const
ObjectiveFunction
*
objective_function
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
/*!
* \brief Traverse the tree of forced splits and check that all indices are less than the number of features.
*/
void
CheckForcedSplitFeatures
();
/*!
/*!
* \brief Merge model from other boosting object. Will insert to the front of current boosting object
* \brief Merge model from other boosting object. Will insert to the front of current boosting object
* \param other
* \param other
...
...
tests/python_package_test/test_engine.py
View file @
f84bfcf9
...
@@ -2887,6 +2887,25 @@ def test_node_level_subcol():
...
@@ -2887,6 +2887,25 @@ def test_node_level_subcol():
assert
ret
!=
ret2
assert
ret
!=
ret2
def test_forced_split_feature_indices(tmp_path):
    """Training must fail with a clear error when the forced-splits file
    references a feature index outside the dataset (regression test for #5517)."""
    X, y = make_synthetic_regression()
    # "left" child deliberately uses X.shape[1], one past the last valid index.
    forced_split = {
        "feature": 0,
        "threshold": 0.5,
        "left": {"feature": X.shape[1], "threshold": 0.5},
    }
    tmp_split_file = tmp_path / "forced_split.json"
    with open(tmp_split_file, "w") as f:
        json.dump(forced_split, f)
    lgb_train = lgb.Dataset(X, y)
    params = {"objective": "regression", "forcedsplits_filename": tmp_split_file}
    # train() should raise before producing a booster; no result to bind.
    with pytest.raises(lgb.basic.LightGBMError, match="Forced splits file includes feature index"):
        lgb.train(params, lgb_train)
def
test_forced_bins
():
def
test_forced_bins
():
x
=
np
.
empty
((
100
,
2
))
x
=
np
.
empty
((
100
,
2
))
x
[:,
0
]
=
np
.
arange
(
0
,
1
,
0.01
)
x
[:,
0
]
=
np
.
arange
(
0
,
1
,
0.01
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment