Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
501e6e62
Unverified
Commit
501e6e62
authored
Nov 09, 2023
by
david-cortes
Committed by
GitHub
Nov 08, 2023
Browse files
[python-package] Accept numpy generators as `random_state` (#6174)
parent
5e90255e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
9 deletions
+24
-9
python-package/lightgbm/compat.py
python-package/lightgbm/compat.py
+10
-0
python-package/lightgbm/dask.py
python-package/lightgbm/dask.py
+3
-3
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+7
-3
tests/python_package_test/test_sklearn.py
tests/python_package_test/test_sklearn.py
+4
-3
No files found.
python-package/lightgbm/compat.py
View file @
501e6e62
...
@@ -36,6 +36,16 @@ except ImportError:
...
@@ -36,6 +36,16 @@ except ImportError:
concat
=
None
concat
=
None
"""numpy"""
try
:
from
numpy.random
import
Generator
as
np_random_Generator
except
ImportError
:
class
np_random_Generator
:
# type: ignore
"""Dummy class for np.random.Generator."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
pass
"""matplotlib"""
"""matplotlib"""
try
:
try
:
import
matplotlib
# noqa: F401
import
matplotlib
# noqa: F401
...
...
python-package/lightgbm/dask.py
View file @
501e6e62
...
@@ -1142,7 +1142,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
...
@@ -1142,7 +1142,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
colsample_bytree
:
float
=
1.
,
colsample_bytree
:
float
=
1.
,
reg_alpha
:
float
=
0.
,
reg_alpha
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
]]
=
None
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
,
'np.random.Generator'
]]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
importance_type
:
str
=
'split'
,
importance_type
:
str
=
'split'
,
client
:
Optional
[
Client
]
=
None
,
client
:
Optional
[
Client
]
=
None
,
...
@@ -1347,7 +1347,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
...
@@ -1347,7 +1347,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
colsample_bytree
:
float
=
1.
,
colsample_bytree
:
float
=
1.
,
reg_alpha
:
float
=
0.
,
reg_alpha
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
]]
=
None
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
,
'np.random.Generator'
]]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
importance_type
:
str
=
'split'
,
importance_type
:
str
=
'split'
,
client
:
Optional
[
Client
]
=
None
,
client
:
Optional
[
Client
]
=
None
,
...
@@ -1517,7 +1517,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
...
@@ -1517,7 +1517,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
colsample_bytree
:
float
=
1.
,
colsample_bytree
:
float
=
1.
,
reg_alpha
:
float
=
0.
,
reg_alpha
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
]]
=
None
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
,
'np.random.Generator'
]]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
importance_type
:
str
=
'split'
,
importance_type
:
str
=
'split'
,
client
:
Optional
[
Client
]
=
None
,
client
:
Optional
[
Client
]
=
None
,
...
...
python-package/lightgbm/sklearn.py
View file @
501e6e62
...
@@ -15,7 +15,7 @@ from .callback import _EvalResultDict, record_evaluation
...
@@ -15,7 +15,7 @@ from .callback import _EvalResultDict, record_evaluation
from
.compat
import
(
SKLEARN_INSTALLED
,
LGBMNotFittedError
,
_LGBMAssertAllFinite
,
_LGBMCheckArray
,
from
.compat
import
(
SKLEARN_INSTALLED
,
LGBMNotFittedError
,
_LGBMAssertAllFinite
,
_LGBMCheckArray
,
_LGBMCheckClassificationTargets
,
_LGBMCheckSampleWeight
,
_LGBMCheckXY
,
_LGBMClassifierBase
,
_LGBMCheckClassificationTargets
,
_LGBMCheckSampleWeight
,
_LGBMCheckXY
,
_LGBMClassifierBase
,
_LGBMComputeSampleWeight
,
_LGBMCpuCount
,
_LGBMLabelEncoder
,
_LGBMModelBase
,
_LGBMRegressorBase
,
_LGBMComputeSampleWeight
,
_LGBMCpuCount
,
_LGBMLabelEncoder
,
_LGBMModelBase
,
_LGBMRegressorBase
,
dt_DataTable
,
pd_DataFrame
)
dt_DataTable
,
np_random_Generator
,
pd_DataFrame
)
from
.engine
import
train
from
.engine
import
train
__all__
=
[
__all__
=
[
...
@@ -448,7 +448,7 @@ class LGBMModel(_LGBMModelBase):
...
@@ -448,7 +448,7 @@ class LGBMModel(_LGBMModelBase):
colsample_bytree
:
float
=
1.
,
colsample_bytree
:
float
=
1.
,
reg_alpha
:
float
=
0.
,
reg_alpha
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
reg_lambda
:
float
=
0.
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
]]
=
None
,
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
,
'np.random.Generator'
]]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
importance_type
:
str
=
'split'
,
importance_type
:
str
=
'split'
,
**
kwargs
**
kwargs
...
@@ -509,7 +509,7 @@ class LGBMModel(_LGBMModelBase):
...
@@ -509,7 +509,7 @@ class LGBMModel(_LGBMModelBase):
random_state : int, RandomState object or None, optional (default=None)
random_state : int, RandomState object or None, optional (default=None)
Random number seed.
Random number seed.
If int, this number is used to seed the C++ code.
If int, this number is used to seed the C++ code.
If RandomState object (numpy), a random integer is picked based on its state to seed the C++ code.
If RandomState
or Generator
object (numpy), a random integer is picked based on its state to seed the C++ code.
If None, default seeds in C++ code are used.
If None, default seeds in C++ code are used.
n_jobs : int or None, optional (default=None)
n_jobs : int or None, optional (default=None)
Number of parallel threads to use for training (can be changed at prediction time by
Number of parallel threads to use for training (can be changed at prediction time by
...
@@ -710,6 +710,10 @@ class LGBMModel(_LGBMModelBase):
...
@@ -710,6 +710,10 @@ class LGBMModel(_LGBMModelBase):
if
isinstance
(
params
[
'random_state'
],
np
.
random
.
RandomState
):
if
isinstance
(
params
[
'random_state'
],
np
.
random
.
RandomState
):
params
[
'random_state'
]
=
params
[
'random_state'
].
randint
(
np
.
iinfo
(
np
.
int32
).
max
)
params
[
'random_state'
]
=
params
[
'random_state'
].
randint
(
np
.
iinfo
(
np
.
int32
).
max
)
elif
isinstance
(
params
[
'random_state'
],
np_random_Generator
):
params
[
'random_state'
]
=
int
(
params
[
'random_state'
].
integers
(
np
.
iinfo
(
np
.
int32
).
max
)
)
if
self
.
_n_classes
>
2
:
if
self
.
_n_classes
>
2
:
for
alias
in
_ConfigAliases
.
get
(
'num_class'
):
for
alias
in
_ConfigAliases
.
get
(
'num_class'
):
params
.
pop
(
alias
,
None
)
params
.
pop
(
alias
,
None
)
...
...
tests/python_package_test/test_sklearn.py
View file @
501e6e62
...
@@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path):
...
@@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path):
assert
gbm
.
booster_
.
attr_set_inside_callback
==
40
assert
gbm
.
booster_
.
attr_set_inside_callback
==
40
def
test_random_state_object
():
@
pytest
.
mark
.
parametrize
(
"rng_constructor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
def
test_random_state_object
(
rng_constructor
):
X
,
y
=
load_iris
(
return_X_y
=
True
)
X
,
y
=
load_iris
(
return_X_y
=
True
)
X_train
,
X_test
,
y_train
,
y_test
=
train_test_split
(
X
,
y
,
test_size
=
0.1
,
random_state
=
42
)
X_train
,
X_test
,
y_train
,
y_test
=
train_test_split
(
X
,
y
,
test_size
=
0.1
,
random_state
=
42
)
state1
=
np
.
random
.
RandomState
(
123
)
state1
=
rng_constructor
(
123
)
state2
=
np
.
random
.
RandomState
(
123
)
state2
=
rng_constructor
(
123
)
clf1
=
lgb
.
LGBMClassifier
(
n_estimators
=
10
,
subsample
=
0.5
,
subsample_freq
=
1
,
random_state
=
state1
)
clf1
=
lgb
.
LGBMClassifier
(
n_estimators
=
10
,
subsample
=
0.5
,
subsample_freq
=
1
,
random_state
=
state1
)
clf2
=
lgb
.
LGBMClassifier
(
n_estimators
=
10
,
subsample
=
0.5
,
subsample_freq
=
1
,
random_state
=
state2
)
clf2
=
lgb
.
LGBMClassifier
(
n_estimators
=
10
,
subsample
=
0.5
,
subsample_freq
=
1
,
random_state
=
state2
)
# Test if random_state is properly stored
# Test if random_state is properly stored
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment