Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
ed1e4f8e
Commit
ed1e4f8e
authored
Aug 05, 2017
by
wxchan
Committed by
Guolin Ke
Aug 18, 2017
Browse files
[python] Dataset params back up before training (#786)
* params back up * refine logic
parent
2367b463
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
8 deletions
+19
-8
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+8
-1
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+11
-7
No files found.
python-package/lightgbm/basic.py
View file @
ed1e4f8e
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
"""Wrapper c_api of LightGBM"""
"""Wrapper c_api of LightGBM"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
copy
import
ctypes
import
ctypes
import
os
import
os
import
warnings
import
warnings
...
@@ -591,11 +592,12 @@ class Dataset(object):
...
@@ -591,11 +592,12 @@ class Dataset(object):
self
.
silent
=
silent
self
.
silent
=
silent
self
.
feature_name
=
feature_name
self
.
feature_name
=
feature_name
self
.
categorical_feature
=
categorical_feature
self
.
categorical_feature
=
categorical_feature
self
.
params
=
params
self
.
params
=
copy
.
deepcopy
(
params
)
self
.
free_raw_data
=
free_raw_data
self
.
free_raw_data
=
free_raw_data
self
.
used_indices
=
None
self
.
used_indices
=
None
self
.
_predictor
=
None
self
.
_predictor
=
None
self
.
pandas_categorical
=
None
self
.
pandas_categorical
=
None
self
.
params_back_up
=
None
def
__del__
(
self
):
def
__del__
(
self
):
self
.
_free_handle
()
self
.
_free_handle
()
...
@@ -872,8 +874,13 @@ class Dataset(object):
...
@@ -872,8 +874,13 @@ class Dataset(object):
if
not
self
.
params
:
if
not
self
.
params
:
self
.
params
=
params
self
.
params
=
params
else
:
else
:
self
.
params_back_up
=
copy
.
deepcopy
(
self
.
params
)
self
.
params
.
update
(
params
)
self
.
params
.
update
(
params
)
def
_reverse_update_params
(
self
):
self
.
params
=
copy
.
deepcopy
(
self
.
params_back_up
)
self
.
params_back_up
=
None
def
set_field
(
self
,
field_name
,
data
):
def
set_field
(
self
,
field_name
,
data
):
"""Set property into the Dataset.
"""Set property into the Dataset.
...
...
python-package/lightgbm/engine.py
View file @
ed1e4f8e
...
@@ -128,14 +128,13 @@ def train(params, train_set, num_boost_round=100,
...
@@ -128,14 +128,13 @@ def train(params, train_set, num_boost_round=100,
continue
continue
if
not
isinstance
(
valid_data
,
Dataset
):
if
not
isinstance
(
valid_data
,
Dataset
):
raise
TypeError
(
"Traninig only accepts Dataset object"
)
raise
TypeError
(
"Traninig only accepts Dataset object"
)
valid_data
.
_update_params
(
params
)
valid_data
.
set_reference
(
train_set
)
valid_data
.
set_reference
(
train_set
)
reduced_valid_sets
.
append
(
valid_data
)
reduced_valid_sets
.
append
(
valid_data
)
if
valid_names
is
not
None
and
len
(
valid_names
)
>
i
:
if
valid_names
is
not
None
and
len
(
valid_names
)
>
i
:
name_valid_sets
.
append
(
valid_names
[
i
])
name_valid_sets
.
append
(
valid_names
[
i
])
else
:
else
:
name_valid_sets
.
append
(
'valid_'
+
str
(
i
))
name_valid_sets
.
append
(
'valid_'
+
str
(
i
))
for
valid_data
in
valid_sets
:
valid_data
.
_update_params
(
params
)
"""process callbacks"""
"""process callbacks"""
if
callbacks
is
None
:
if
callbacks
is
None
:
callbacks
=
set
()
callbacks
=
set
()
...
@@ -165,11 +164,16 @@ def train(params, train_set, num_boost_round=100,
...
@@ -165,11 +164,16 @@ def train(params, train_set, num_boost_round=100,
callbacks_after_iter
=
sorted
(
callbacks_after_iter
,
key
=
attrgetter
(
'order'
))
callbacks_after_iter
=
sorted
(
callbacks_after_iter
,
key
=
attrgetter
(
'order'
))
"""construct booster"""
"""construct booster"""
booster
=
Booster
(
params
=
params
,
train_set
=
train_set
)
try
:
if
is_valid_contain_train
:
booster
=
Booster
(
params
=
params
,
train_set
=
train_set
)
booster
.
set_train_data_name
(
train_data_name
)
if
is_valid_contain_train
:
for
valid_set
,
name_valid_set
in
zip
(
reduced_valid_sets
,
name_valid_sets
):
booster
.
set_train_data_name
(
train_data_name
)
booster
.
add_valid
(
valid_set
,
name_valid_set
)
for
valid_set
,
name_valid_set
in
zip
(
reduced_valid_sets
,
name_valid_sets
):
booster
.
add_valid
(
valid_set
,
name_valid_set
)
finally
:
train_set
.
_reverse_update_params
()
for
valid_set
in
reduced_valid_sets
:
valid_set
.
_reverse_update_params
()
booster
.
best_iteration
=
0
booster
.
best_iteration
=
0
"""start training"""
"""start training"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment