Unverified Commit 501e6e62 authored by david-cortes's avatar david-cortes Committed by GitHub
Browse files

[python-package] Accept numpy generators as `random_state` (#6174)

parent 5e90255e
...@@ -36,6 +36,16 @@ except ImportError: ...@@ -36,6 +36,16 @@ except ImportError:
concat = None concat = None
"""numpy"""
try:
from numpy.random import Generator as np_random_Generator
except ImportError:
class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator."""
def __init__(self, *args, **kwargs):
pass
"""matplotlib""" """matplotlib"""
try: try:
import matplotlib # noqa: F401 import matplotlib # noqa: F401
......
...@@ -1142,7 +1142,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1142,7 +1142,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
colsample_bytree: float = 1., colsample_bytree: float = 1.,
reg_alpha: float = 0., reg_alpha: float = 0.,
reg_lambda: float = 0., reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None, random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = 'split',
client: Optional[Client] = None, client: Optional[Client] = None,
...@@ -1347,7 +1347,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1347,7 +1347,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
colsample_bytree: float = 1., colsample_bytree: float = 1.,
reg_alpha: float = 0., reg_alpha: float = 0.,
reg_lambda: float = 0., reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None, random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = 'split',
client: Optional[Client] = None, client: Optional[Client] = None,
...@@ -1517,7 +1517,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1517,7 +1517,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
colsample_bytree: float = 1., colsample_bytree: float = 1.,
reg_alpha: float = 0., reg_alpha: float = 0.,
reg_lambda: float = 0., reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None, random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = 'split',
client: Optional[Client] = None, client: Optional[Client] = None,
......
...@@ -15,7 +15,7 @@ from .callback import _EvalResultDict, record_evaluation ...@@ -15,7 +15,7 @@ from .callback import _EvalResultDict, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray, from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase, _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, _LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
dt_DataTable, pd_DataFrame) dt_DataTable, np_random_Generator, pd_DataFrame)
from .engine import train from .engine import train
__all__ = [ __all__ = [
...@@ -448,7 +448,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -448,7 +448,7 @@ class LGBMModel(_LGBMModelBase):
colsample_bytree: float = 1., colsample_bytree: float = 1.,
reg_alpha: float = 0., reg_alpha: float = 0.,
reg_lambda: float = 0., reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None, random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = 'split',
**kwargs **kwargs
...@@ -509,7 +509,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -509,7 +509,7 @@ class LGBMModel(_LGBMModelBase):
random_state : int, RandomState object or None, optional (default=None) random_state : int, RandomState object or None, optional (default=None)
Random number seed. Random number seed.
If int, this number is used to seed the C++ code. If int, this number is used to seed the C++ code.
If RandomState object (numpy), a random integer is picked based on its state to seed the C++ code. If RandomState or Generator object (numpy), a random integer is picked based on its state to seed the C++ code.
If None, default seeds in C++ code are used. If None, default seeds in C++ code are used.
n_jobs : int or None, optional (default=None) n_jobs : int or None, optional (default=None)
Number of parallel threads to use for training (can be changed at prediction time by Number of parallel threads to use for training (can be changed at prediction time by
...@@ -710,6 +710,10 @@ class LGBMModel(_LGBMModelBase): ...@@ -710,6 +710,10 @@ class LGBMModel(_LGBMModelBase):
if isinstance(params['random_state'], np.random.RandomState): if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max) params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
elif isinstance(params['random_state'], np_random_Generator):
params['random_state'] = int(
params['random_state'].integers(np.iinfo(np.int32).max)
)
if self._n_classes > 2: if self._n_classes > 2:
for alias in _ConfigAliases.get('num_class'): for alias in _ConfigAliases.get('num_class'):
params.pop(alias, None) params.pop(alias, None)
......
...@@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path): ...@@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path):
assert gbm.booster_.attr_set_inside_callback == 40 assert gbm.booster_.attr_set_inside_callback == 40
def test_random_state_object(): @pytest.mark.parametrize("rng_constructor", [np.random.RandomState, np.random.default_rng])
def test_random_state_object(rng_constructor):
X, y = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
state1 = np.random.RandomState(123) state1 = rng_constructor(123)
state2 = np.random.RandomState(123) state2 = rng_constructor(123)
clf1 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state1) clf1 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state1)
clf2 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state2) clf2 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state2)
# Test if random_state is properly stored # Test if random_state is properly stored
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment