Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
b793cd82
Unverified
Commit
b793cd82
authored
Oct 06, 2023
by
José Morales
Committed by
GitHub
Oct 06, 2023
Browse files
ignore unknown parameters when loading from model file (#6126)
parent
8f577de0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
5 deletions
+25
-5
src/boosting/gbdt.h
src/boosting/gbdt.h
+8
-3
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+17
-2
No files found.
src/boosting/gbdt.h
View file @
b793cd82
...
...
@@ -179,15 +179,20 @@ class GBDT : public GBDTBase {
const
auto
pair
=
Common
::
Split
(
line
.
c_str
(),
":"
);
if
(
pair
[
1
]
==
" ]"
)
continue
;
const
auto
param
=
pair
[
0
].
substr
(
1
);
const
auto
value_str
=
pair
[
1
].
substr
(
1
,
pair
[
1
].
size
()
-
2
);
auto
iter
=
param_types
.
find
(
param
);
if
(
iter
==
param_types
.
end
())
{
Log
::
Warning
(
"Ignoring unrecognized parameter '%s' found in model string."
,
param
.
c_str
());
continue
;
}
std
::
string
param_type
=
iter
->
second
;
if
(
first
)
{
first
=
false
;
str_buf
<<
"
\"
"
;
}
else
{
str_buf
<<
",
\"
"
;
}
const
auto
param
=
pair
[
0
].
substr
(
1
);
const
auto
value_str
=
pair
[
1
].
substr
(
1
,
pair
[
1
].
size
()
-
2
);
const
auto
param_type
=
param_types
.
at
(
param
);
str_buf
<<
param
<<
"
\"
: "
;
if
(
param_type
==
"string"
)
{
str_buf
<<
"
\"
"
<<
value_str
<<
"
\"
"
;
...
...
tests/python_package_test/test_engine.py
View file @
b793cd82
...
...
@@ -1470,7 +1470,7 @@ def test_feature_name_with_non_ascii():
assert
feature_names
==
gbm2
.
feature_name
()
def
test_parameters_are_loaded_from_model_file
(
tmp_path
):
def
test_parameters_are_loaded_from_model_file
(
tmp_path
,
capsys
):
X
=
np
.
hstack
([
np
.
random
.
rand
(
100
,
1
),
np
.
random
.
randint
(
0
,
5
,
(
100
,
2
))])
y
=
np
.
random
.
rand
(
100
)
ds
=
lgb
.
Dataset
(
X
,
y
)
...
...
@@ -1487,8 +1487,18 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
'num_threads'
:
1
,
}
model_file
=
tmp_path
/
'model.txt'
lgb
.
train
(
params
,
ds
,
num_boost_round
=
1
,
categorical_feature
=
[
1
,
2
]).
save_model
(
model_file
)
orig_bst
=
lgb
.
train
(
params
,
ds
,
num_boost_round
=
1
,
categorical_feature
=
[
1
,
2
])
orig_bst
.
save_model
(
model_file
)
with
model_file
.
open
(
'rt'
)
as
f
:
model_contents
=
f
.
readlines
()
params_start
=
model_contents
.
index
(
'parameters:
\n
'
)
model_contents
.
insert
(
params_start
+
1
,
'[max_conflict_rate: 0]
\n
'
)
with
model_file
.
open
(
'wt'
)
as
f
:
f
.
writelines
(
model_contents
)
bst
=
lgb
.
Booster
(
model_file
=
model_file
)
expected_msg
=
"[LightGBM] [Warning] Ignoring unrecognized parameter 'max_conflict_rate' found in model string."
stdout
=
capsys
.
readouterr
().
out
assert
expected_msg
in
stdout
set_params
=
{
k
:
bst
.
params
[
k
]
for
k
in
params
.
keys
()}
assert
set_params
==
params
assert
bst
.
params
[
'categorical_feature'
]
==
[
1
,
2
]
...
...
@@ -1498,6 +1508,11 @@ def test_parameters_are_loaded_from_model_file(tmp_path):
bst2
=
lgb
.
Booster
(
params
=
{
'num_leaves'
:
7
},
model_file
=
model_file
)
assert
bst
.
params
==
bst2
.
params
# check inference isn't affected by unknown parameter
orig_preds
=
orig_bst
.
predict
(
X
)
preds
=
bst
.
predict
(
X
)
np
.
testing
.
assert_allclose
(
preds
,
orig_preds
)
def
test_save_load_copy_pickle
():
def
train_and_predict
(
init_model
=
None
,
return_model
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment