Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
646267d2
Unverified
Commit
646267d2
authored
Feb 20, 2021
by
James Lamb
Committed by
GitHub
Feb 20, 2021
Browse files
[dask] use more specific method names on _DaskLGBMModel (#4004)
parent
7f91dc66
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
15 deletions
+15
-15
python-package/lightgbm/dask.py
python-package/lightgbm/dask.py
+15
-15
No files found.
python-package/lightgbm/dask.py
View file @
646267d2
...
...
@@ -465,7 +465,7 @@ class _DaskLGBMModel:
return
_get_dask_client
(
client
=
self
.
client
)
def
_lgb_getstate
(
self
)
->
Dict
[
Any
,
Any
]:
def
_lgb_
dask_
getstate
(
self
)
->
Dict
[
Any
,
Any
]:
"""Remove un-picklable attributes before serialization."""
client
=
self
.
__dict__
.
pop
(
"client"
,
None
)
self
.
_other_params
.
pop
(
"client"
,
None
)
...
...
@@ -474,7 +474,7 @@ class _DaskLGBMModel:
self
.
client
=
client
return
out
def
_fit
(
def
_lgb_dask
_fit
(
self
,
model_factory
:
Type
[
LGBMModel
],
X
:
_DaskMatrixLike
,
...
...
@@ -501,20 +501,20 @@ class _DaskLGBMModel:
)
self
.
set_params
(
**
model
.
get_params
())
self
.
_copy_extra_params
(
model
,
self
)
self
.
_
lgb_dask_
copy_extra_params
(
model
,
self
)
return
self
def
_to_local
(
self
,
model_factory
:
Type
[
LGBMModel
])
->
LGBMModel
:
def
_lgb_dask
_to_local
(
self
,
model_factory
:
Type
[
LGBMModel
])
->
LGBMModel
:
params
=
self
.
get_params
()
params
.
pop
(
"client"
,
None
)
model
=
model_factory
(
**
params
)
self
.
_copy_extra_params
(
self
,
model
)
self
.
_
lgb_dask_
copy_extra_params
(
self
,
model
)
model
.
_other_params
.
pop
(
"client"
,
None
)
return
model
@
staticmethod
def
_copy_extra_params
(
source
:
Union
[
"_DaskLGBMModel"
,
LGBMModel
],
dest
:
Union
[
"_DaskLGBMModel"
,
LGBMModel
])
->
None
:
def
_lgb_dask
_copy_extra_params
(
source
:
Union
[
"_DaskLGBMModel"
,
LGBMModel
],
dest
:
Union
[
"_DaskLGBMModel"
,
LGBMModel
])
->
None
:
params
=
source
.
get_params
()
attributes
=
source
.
__dict__
extra_param_names
=
set
(
attributes
.
keys
()).
difference
(
params
.
keys
())
...
...
@@ -590,7 +590,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
__init__
.
__doc__
=
_base_doc
[:
_base_doc
.
find
(
'Note
\n
'
)]
def
__getstate__
(
self
)
->
Dict
[
Any
,
Any
]:
return
self
.
_lgb_getstate
()
return
self
.
_lgb_
dask_
getstate
()
def
fit
(
self
,
...
...
@@ -600,7 +600,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
**
kwargs
:
Any
)
->
"DaskLGBMClassifier"
:
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return
self
.
_fit
(
return
self
.
_
lgb_dask_
fit
(
model_factory
=
LGBMClassifier
,
X
=
X
,
y
=
y
,
...
...
@@ -670,7 +670,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
model : lightgbm.LGBMClassifier
Local underlying model.
"""
return
self
.
_to_local
(
LGBMClassifier
)
return
self
.
_
lgb_dask_
to_local
(
LGBMClassifier
)
class
DaskLGBMRegressor
(
LGBMRegressor
,
_DaskLGBMModel
):
...
...
@@ -741,7 +741,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
__init__
.
__doc__
=
_base_doc
[:
_base_doc
.
find
(
'Note
\n
'
)]
def
__getstate__
(
self
)
->
Dict
[
Any
,
Any
]:
return
self
.
_lgb_getstate
()
return
self
.
_lgb_
dask_
getstate
()
def
fit
(
self
,
...
...
@@ -751,7 +751,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
**
kwargs
:
Any
)
->
"DaskLGBMRegressor"
:
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return
self
.
_fit
(
return
self
.
_
lgb_dask_
fit
(
model_factory
=
LGBMRegressor
,
X
=
X
,
y
=
y
,
...
...
@@ -802,7 +802,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
model : lightgbm.LGBMRegressor
Local underlying model.
"""
return
self
.
_to_local
(
LGBMRegressor
)
return
self
.
_
lgb_dask_
to_local
(
LGBMRegressor
)
class
DaskLGBMRanker
(
LGBMRanker
,
_DaskLGBMModel
):
...
...
@@ -873,7 +873,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
__init__
.
__doc__
=
_base_doc
[:
_base_doc
.
find
(
'Note
\n
'
)]
def
__getstate__
(
self
)
->
Dict
[
Any
,
Any
]:
return
self
.
_lgb_getstate
()
return
self
.
_lgb_
dask_
getstate
()
def
fit
(
self
,
...
...
@@ -888,7 +888,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
if
init_score
is
not
None
:
raise
RuntimeError
(
'init_score is not currently supported in lightgbm.dask'
)
return
self
.
_fit
(
return
self
.
_
lgb_dask_
fit
(
model_factory
=
LGBMRanker
,
X
=
X
,
y
=
y
,
...
...
@@ -939,4 +939,4 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
model : lightgbm.LGBMRanker
Local underlying model.
"""
return
self
.
_to_local
(
LGBMRanker
)
return
self
.
_
lgb_dask_
to_local
(
LGBMRanker
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment