"src/vscode:/vscode.git/clone" did not exist on "f0cfbff63f8a228784e52e033533e5cb3fa1b97b"
Unverified Commit 501e6e62 authored by david-cortes's avatar david-cortes Committed by GitHub
Browse files

[python-package] Accept numpy generators as `random_state` (#6174)

parent 5e90255e
......@@ -36,6 +36,16 @@ except ImportError:
concat = None
"""numpy"""
try:
from numpy.random import Generator as np_random_Generator
except ImportError:
class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator."""
def __init__(self, *args, **kwargs):
pass
"""matplotlib"""
try:
import matplotlib # noqa: F401
......
......@@ -1142,7 +1142,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
client: Optional[Client] = None,
......@@ -1347,7 +1347,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
client: Optional[Client] = None,
......@@ -1517,7 +1517,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
client: Optional[Client] = None,
......
......@@ -15,7 +15,7 @@ from .callback import _EvalResultDict, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
dt_DataTable, pd_DataFrame)
dt_DataTable, np_random_Generator, pd_DataFrame)
from .engine import train
__all__ = [
......@@ -448,7 +448,7 @@ class LGBMModel(_LGBMModelBase):
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
**kwargs
......@@ -509,7 +509,7 @@ class LGBMModel(_LGBMModelBase):
random_state : int, RandomState object or None, optional (default=None)
Random number seed.
If int, this number is used to seed the C++ code.
If RandomState object (numpy), a random integer is picked based on its state to seed the C++ code.
If RandomState or Generator object (numpy), a random integer is picked based on its state to seed the C++ code.
If None, default seeds in C++ code are used.
n_jobs : int or None, optional (default=None)
Number of parallel threads to use for training (can be changed at prediction time by
......@@ -710,6 +710,10 @@ class LGBMModel(_LGBMModelBase):
if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
elif isinstance(params['random_state'], np_random_Generator):
params['random_state'] = int(
params['random_state'].integers(np.iinfo(np.int32).max)
)
if self._n_classes > 2:
for alias in _ConfigAliases.get('num_class'):
params.pop(alias, None)
......
......@@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path):
assert gbm.booster_.attr_set_inside_callback == 40
def test_random_state_object():
@pytest.mark.parametrize("rng_constructor", [np.random.RandomState, np.random.default_rng])
def test_random_state_object(rng_constructor):
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
state1 = np.random.RandomState(123)
state2 = np.random.RandomState(123)
state1 = rng_constructor(123)
state2 = rng_constructor(123)
clf1 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state1)
clf2 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state2)
# Test if random_state is properly stored
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment