Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
861de1c1
Commit
861de1c1
authored
Feb 02, 2019
by
Nikita Titov
Committed by
Guolin Ke
Feb 02, 2019
Browse files
improved model loading routines (#1979)
parent
40486b6c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
24 deletions
+46
-24
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+21
-14
src/boosting/gbdt_model_text.cpp
src/boosting/gbdt_model_text.cpp
+3
-3
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+8
-0
tests/python_package_test/test_sklearn.py
tests/python_package_test/test_sklearn.py
+14
-7
No files found.
python-package/lightgbm/basic.py
View file @
861de1c1
...
@@ -306,21 +306,28 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
...
@@ -306,21 +306,28 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
def
_load_pandas_categorical
(
file_name
=
None
,
model_str
=
None
):
def
_load_pandas_categorical
(
file_name
=
None
,
model_str
=
None
):
pandas_key
=
'pandas_categorical:'
offset
=
-
len
(
pandas_key
)
if
file_name
is
not
None
:
if
file_name
is
not
None
:
with
open
(
file_name
,
'r'
)
as
f
:
max_offset
=
-
os
.
path
.
getsize
(
file_name
)
with
open
(
file_name
,
'rb'
)
as
f
:
while
True
:
if
offset
<
max_offset
:
offset
=
max_offset
f
.
seek
(
offset
,
os
.
SEEK_END
)
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
last_line
=
lines
[
-
1
]
if
len
(
lines
)
>=
2
:
if
last_line
.
strip
()
==
""
:
break
last_line
=
lines
[
-
2
]
offset
*=
2
if
last_line
.
startswith
(
'pandas_categorical:'
):
last_line
=
decode_string
(
lines
[
-
1
]).
strip
()
return
json
.
loads
(
last_line
[
len
(
'pandas_categorical:'
):])
if
not
last_line
.
startswith
(
pandas_key
):
last_line
=
decode_string
(
lines
[
-
2
]).
strip
()
elif
model_str
is
not
None
:
elif
model_str
is
not
None
:
lines
=
model_str
.
split
(
'
\n
'
)
idx
=
model_str
.
rfind
(
'
\n
'
,
0
,
offset
)
last_line
=
lines
[
-
1
]
last_line
=
model_str
[
idx
:].
strip
()
if
last_line
.
strip
()
==
""
:
if
last_line
.
startswith
(
pandas_key
):
last_line
=
lines
[
-
2
]
return
json
.
loads
(
last_line
[
len
(
pandas_key
):])
if
last_line
.
startswith
(
'pandas_categorical:'
):
else
:
return
json
.
loads
(
last_line
[
len
(
'pandas_categorical:'
):])
return
None
return
None
...
...
src/boosting/gbdt_model_text.cpp
View file @
861de1c1
...
@@ -349,8 +349,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
...
@@ -349,8 +349,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
std
::
unordered_map
<
std
::
string
,
std
::
string
>
key_vals
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
key_vals
;
while
(
p
<
end
)
{
while
(
p
<
end
)
{
auto
line_len
=
Common
::
GetLine
(
p
);
auto
line_len
=
Common
::
GetLine
(
p
);
std
::
string
cur_line
(
p
,
line_len
);
if
(
line_len
>
0
)
{
if
(
line_len
>
0
)
{
std
::
string
cur_line
(
p
,
line_len
);
if
(
!
Common
::
StartsWith
(
cur_line
,
"Tree="
))
{
if
(
!
Common
::
StartsWith
(
cur_line
,
"Tree="
))
{
auto
strs
=
Common
::
Split
(
cur_line
.
c_str
(),
'='
);
auto
strs
=
Common
::
Split
(
cur_line
.
c_str
(),
'='
);
if
(
strs
.
size
()
==
1
)
{
if
(
strs
.
size
()
==
1
)
{
...
@@ -442,8 +442,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
...
@@ -442,8 +442,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
if
(
!
key_vals
.
count
(
"tree_sizes"
))
{
if
(
!
key_vals
.
count
(
"tree_sizes"
))
{
while
(
p
<
end
)
{
while
(
p
<
end
)
{
auto
line_len
=
Common
::
GetLine
(
p
);
auto
line_len
=
Common
::
GetLine
(
p
);
std
::
string
cur_line
(
p
,
line_len
);
if
(
line_len
>
0
)
{
if
(
line_len
>
0
)
{
std
::
string
cur_line
(
p
,
line_len
);
if
(
Common
::
StartsWith
(
cur_line
,
"Tree="
))
{
if
(
Common
::
StartsWith
(
cur_line
,
"Tree="
))
{
p
+=
line_len
;
p
+=
line_len
;
p
=
Common
::
SkipNewLine
(
p
);
p
=
Common
::
SkipNewLine
(
p
);
...
@@ -491,8 +491,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
...
@@ -491,8 +491,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
std
::
stringstream
ss
;
std
::
stringstream
ss
;
while
(
p
<
end
)
{
while
(
p
<
end
)
{
auto
line_len
=
Common
::
GetLine
(
p
);
auto
line_len
=
Common
::
GetLine
(
p
);
std
::
string
cur_line
(
p
,
line_len
);
if
(
line_len
>
0
)
{
if
(
line_len
>
0
)
{
std
::
string
cur_line
(
p
,
line_len
);
if
(
cur_line
==
std
::
string
(
"parameters:"
))
{
if
(
cur_line
==
std
::
string
(
"parameters:"
))
{
is_inparameter
=
true
;
is_inparameter
=
true
;
}
else
if
(
cur_line
==
std
::
string
(
"end of parameters"
))
{
}
else
if
(
cur_line
==
std
::
string
(
"end of parameters"
))
{
...
...
tests/python_package_test/test_engine.py
View file @
861de1c1
...
@@ -551,9 +551,11 @@ class TestEngine(unittest.TestCase):
...
@@ -551,9 +551,11 @@ class TestEngine(unittest.TestCase):
"B"
:
np
.
random
.
permutation
([
1
,
3
]
*
30
),
"B"
:
np
.
random
.
permutation
([
1
,
3
]
*
30
),
"C"
:
np
.
random
.
permutation
([
0.1
,
-
0.1
,
0.2
,
0.2
]
*
15
),
"C"
:
np
.
random
.
permutation
([
0.1
,
-
0.1
,
0.2
,
0.2
]
*
15
),
"D"
:
np
.
random
.
permutation
([
True
,
False
]
*
30
)})
"D"
:
np
.
random
.
permutation
([
True
,
False
]
*
30
)})
cat_cols
=
[]
for
col
in
[
"A"
,
"B"
,
"C"
,
"D"
]:
for
col
in
[
"A"
,
"B"
,
"C"
,
"D"
]:
X
[
col
]
=
X
[
col
].
astype
(
'category'
)
X
[
col
]
=
X
[
col
].
astype
(
'category'
)
X_test
[
col
]
=
X_test
[
col
].
astype
(
'category'
)
X_test
[
col
]
=
X_test
[
col
].
astype
(
'category'
)
cat_cols
.
append
(
X
[
col
].
cat
.
categories
.
tolist
())
params
=
{
params
=
{
'objective'
:
'binary'
,
'objective'
:
'binary'
,
'metric'
:
'binary_logloss'
,
'metric'
:
'binary_logloss'
,
...
@@ -588,6 +590,12 @@ class TestEngine(unittest.TestCase):
...
@@ -588,6 +590,12 @@ class TestEngine(unittest.TestCase):
np
.
testing
.
assert_almost_equal
(
pred0
,
pred4
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred4
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred5
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred5
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred6
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred6
)
self
.
assertListEqual
(
gbm0
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm1
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm2
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm3
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm4
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm5
.
pandas_categorical
,
cat_cols
)
def
test_reference_chain
(
self
):
def
test_reference_chain
(
self
):
X
=
np
.
random
.
normal
(
size
=
(
100
,
2
))
X
=
np
.
random
.
normal
(
size
=
(
100
,
2
))
...
...
tests/python_package_test/test_sklearn.py
View file @
861de1c1
...
@@ -215,25 +215,32 @@ class TestSklearn(unittest.TestCase):
...
@@ -215,25 +215,32 @@ class TestSklearn(unittest.TestCase):
"B"
:
np
.
random
.
permutation
([
1
,
3
]
*
30
),
"B"
:
np
.
random
.
permutation
([
1
,
3
]
*
30
),
"C"
:
np
.
random
.
permutation
([
0.1
,
-
0.1
,
0.2
,
0.2
]
*
15
),
"C"
:
np
.
random
.
permutation
([
0.1
,
-
0.1
,
0.2
,
0.2
]
*
15
),
"D"
:
np
.
random
.
permutation
([
True
,
False
]
*
30
)})
"D"
:
np
.
random
.
permutation
([
True
,
False
]
*
30
)})
cat_cols
=
[]
for
col
in
[
"A"
,
"B"
,
"C"
,
"D"
]:
for
col
in
[
"A"
,
"B"
,
"C"
,
"D"
]:
X
[
col
]
=
X
[
col
].
astype
(
'category'
)
X
[
col
]
=
X
[
col
].
astype
(
'category'
)
X_test
[
col
]
=
X_test
[
col
].
astype
(
'category'
)
X_test
[
col
]
=
X_test
[
col
].
astype
(
'category'
)
cat_cols
.
append
(
X
[
col
].
cat
.
categories
.
tolist
())
gbm0
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
)
gbm0
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
)
pred0
=
list
(
gbm0
.
predict
(
X_test
)
)
pred0
=
gbm0
.
predict
(
X_test
)
gbm1
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
,
categorical_feature
=
[
0
])
gbm1
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
pd
.
Series
(
y
)
,
categorical_feature
=
[
0
])
pred1
=
list
(
gbm1
.
predict
(
X_test
)
)
pred1
=
gbm1
.
predict
(
X_test
)
gbm2
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
,
categorical_feature
=
[
'A'
])
gbm2
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
,
categorical_feature
=
[
'A'
])
pred2
=
list
(
gbm2
.
predict
(
X_test
)
)
pred2
=
gbm2
.
predict
(
X_test
)
gbm3
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
,
categorical_feature
=
[
'A'
,
'B'
,
'C'
,
'D'
])
gbm3
=
lgb
.
sklearn
.
LGBMClassifier
().
fit
(
X
,
y
,
categorical_feature
=
[
'A'
,
'B'
,
'C'
,
'D'
])
pred3
=
list
(
gbm3
.
predict
(
X_test
)
)
pred3
=
gbm3
.
predict
(
X_test
)
gbm3
.
booster_
.
save_model
(
'categorical.model'
)
gbm3
.
booster_
.
save_model
(
'categorical.model'
)
gbm4
=
lgb
.
Booster
(
model_file
=
'categorical.model'
)
gbm4
=
lgb
.
Booster
(
model_file
=
'categorical.model'
)
pred4
=
list
(
gbm4
.
predict
(
X_test
)
)
pred4
=
gbm4
.
predict
(
X_test
)
pred_prob
=
list
(
gbm0
.
predict_proba
(
X_test
)[:,
1
]
)
pred_prob
=
gbm0
.
predict_proba
(
X_test
)[:,
1
]
np
.
testing
.
assert_almost_equal
(
pred0
,
pred1
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred1
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred2
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred2
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred3
)
np
.
testing
.
assert_almost_equal
(
pred0
,
pred3
)
np
.
testing
.
assert_almost_equal
(
pred_prob
,
pred4
)
np
.
testing
.
assert_almost_equal
(
pred_prob
,
pred4
)
self
.
assertListEqual
(
gbm0
.
booster_
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm1
.
booster_
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm2
.
booster_
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm3
.
booster_
.
pandas_categorical
,
cat_cols
)
self
.
assertListEqual
(
gbm4
.
pandas_categorical
,
cat_cols
)
def
test_predict
(
self
):
def
test_predict
(
self
):
iris
=
load_iris
()
iris
=
load_iris
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment