"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "ce3e31219c568323d66fa6e918fffcb96b3df92a"
Unverified commit ec5492f8, authored by James Lamb and committed by GitHub
Browse files

[python-package] drop support for h2o datatable (#6894)

parent a7253603
...@@ -35,7 +35,7 @@ The LightGBM Python module can load data from: ...@@ -35,7 +35,7 @@ The LightGBM Python module can load data from:
- LibSVM (zero-based) / TSV / CSV format text file - LibSVM (zero-based) / TSV / CSV format text file
- NumPy 2D array(s), pandas DataFrame, H2O DataTable's Frame (deprecated), SciPy sparse matrix - NumPy 2D array(s), pandas DataFrame, pyarrow Table, SciPy sparse matrix
- LightGBM binary file - LightGBM binary file
......
...@@ -100,7 +100,6 @@ autodoc_default_options = { ...@@ -100,7 +100,6 @@ autodoc_default_options = {
autodoc_mock_imports = [ autodoc_mock_imports = [
"dask", "dask",
"dask.distributed", "dask.distributed",
"datatable",
"graphviz", "graphviz",
"matplotlib", "matplotlib",
"numpy", "numpy",
......
...@@ -35,7 +35,6 @@ from .compat import ( ...@@ -35,7 +35,6 @@ from .compat import (
arrow_is_floating, arrow_is_floating,
arrow_is_integer, arrow_is_integer,
concat, concat,
dt_DataTable,
pa_Array, pa_Array,
pa_chunked_array, pa_chunked_array,
pa_ChunkedArray, pa_ChunkedArray,
...@@ -116,7 +115,6 @@ _LGBM_TrainDataType = Union[ ...@@ -116,7 +115,6 @@ _LGBM_TrainDataType = Union[
Path, Path,
np.ndarray, np.ndarray,
pd_DataFrame, pd_DataFrame,
dt_DataTable,
scipy.sparse.spmatrix, scipy.sparse.spmatrix,
"Sequence", "Sequence",
List["Sequence"], List["Sequence"],
...@@ -137,7 +135,6 @@ _LGBM_PredictDataType = Union[ ...@@ -137,7 +135,6 @@ _LGBM_PredictDataType = Union[
Path, Path,
np.ndarray, np.ndarray,
pd_DataFrame, pd_DataFrame,
dt_DataTable,
scipy.sparse.spmatrix, scipy.sparse.spmatrix,
pa_Table, pa_Table,
] ]
...@@ -577,15 +574,6 @@ class LGBMDeprecationWarning(FutureWarning): ...@@ -577,15 +574,6 @@ class LGBMDeprecationWarning(FutureWarning):
pass pass
def _emit_datatable_deprecation_warning() -> None:
msg = (
"Support for 'datatable' in LightGBM is deprecated, and will be removed in a future release. "
"To avoid this warning, convert 'datatable' inputs to a supported format "
"(for example, use the 'to_numpy()' method)."
)
warnings.warn(msg, category=LGBMDeprecationWarning, stacklevel=2)
class _ConfigAliases: class _ConfigAliases:
# lazy evaluation to allow import without dynamic library, e.g., for docs generation # lazy evaluation to allow import without dynamic library, e.g., for docs generation
aliases = None aliases = None
...@@ -1112,7 +1100,7 @@ class _InnerPredictor: ...@@ -1112,7 +1100,7 @@ class _InnerPredictor:
Parameters Parameters
---------- ----------
data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame (deprecated) or scipy.sparse data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table or scipy.sparse
Data source for prediction. Data source for prediction.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0) start_iteration : int, optional (default=0)
...@@ -1225,14 +1213,6 @@ class _InnerPredictor: ...@@ -1225,14 +1213,6 @@ class _InnerPredictor:
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type, predict_type=predict_type,
) )
elif isinstance(data, dt_DataTable):
_emit_datatable_deprecation_warning()
preds, nrow = self.__pred_for_np2d(
mat=data.to_numpy(),
start_iteration=start_iteration,
num_iteration=num_iteration,
predict_type=predict_type,
)
else: else:
try: try:
_log_warning("Converting data to scipy sparse matrix.") _log_warning("Converting data to scipy sparse matrix.")
...@@ -1790,7 +1770,7 @@ class Dataset: ...@@ -1790,7 +1770,7 @@ class Dataset:
Parameters Parameters
---------- ----------
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table data : str, pathlib.Path, numpy array, pandas DataFrame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
Data source of Dataset. Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
...@@ -2196,9 +2176,6 @@ class Dataset: ...@@ -2196,9 +2176,6 @@ class Dataset:
raise TypeError("Data list can only be of ndarray or Sequence") raise TypeError("Data list can only be of ndarray or Sequence")
elif isinstance(data, Sequence): elif isinstance(data, Sequence):
self.__init_from_seqs([data], ref_dataset) self.__init_from_seqs([data], ref_dataset)
elif isinstance(data, dt_DataTable):
_emit_datatable_deprecation_warning()
self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
...@@ -2619,7 +2596,7 @@ class Dataset: ...@@ -2619,7 +2596,7 @@ class Dataset:
Parameters Parameters
---------- ----------
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, Sequence, list of Sequence or list of numpy array data : str, pathlib.Path, numpy array, pandas DataFrame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source of Dataset. Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
...@@ -3276,7 +3253,7 @@ class Dataset: ...@@ -3276,7 +3253,7 @@ class Dataset:
Returns Returns
------- -------
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, Sequence, list of Sequence or list of numpy array or None data : str, pathlib.Path, numpy array, pandas DataFrame, scipy.sparse, Sequence, list of Sequence or list of numpy array or None
Raw data used in the Dataset construction. Raw data used in the Dataset construction.
""" """
if self._handle is None: if self._handle is None:
...@@ -3288,9 +3265,6 @@ class Dataset: ...@@ -3288,9 +3265,6 @@ class Dataset:
self.data = self.data[self.used_indices, :] self.data = self.data[self.used_indices, :]
elif isinstance(self.data, pd_DataFrame): elif isinstance(self.data, pd_DataFrame):
self.data = self.data.iloc[self.used_indices].copy() self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, dt_DataTable):
_emit_datatable_deprecation_warning()
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, Sequence): elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices] self.data = self.data[self.used_indices]
elif _is_list_of_sequences(self.data) and len(self.data) > 0: elif _is_list_of_sequences(self.data) and len(self.data) > 0:
...@@ -3477,9 +3451,6 @@ class Dataset: ...@@ -3477,9 +3451,6 @@ class Dataset:
self.data = np.hstack((self.data, other.data.toarray())) self.data = np.hstack((self.data, other.data.toarray()))
elif isinstance(other.data, pd_DataFrame): elif isinstance(other.data, pd_DataFrame):
self.data = np.hstack((self.data, other.data.values)) self.data = np.hstack((self.data, other.data.values))
elif isinstance(other.data, dt_DataTable):
_emit_datatable_deprecation_warning()
self.data = np.hstack((self.data, other.data.to_numpy()))
else: else:
self.data = None self.data = None
elif isinstance(self.data, scipy.sparse.spmatrix): elif isinstance(self.data, scipy.sparse.spmatrix):
...@@ -3488,9 +3459,6 @@ class Dataset: ...@@ -3488,9 +3459,6 @@ class Dataset:
self.data = scipy.sparse.hstack((self.data, other.data), format=sparse_format) self.data = scipy.sparse.hstack((self.data, other.data), format=sparse_format)
elif isinstance(other.data, pd_DataFrame): elif isinstance(other.data, pd_DataFrame):
self.data = scipy.sparse.hstack((self.data, other.data.values), format=sparse_format) self.data = scipy.sparse.hstack((self.data, other.data.values), format=sparse_format)
elif isinstance(other.data, dt_DataTable):
_emit_datatable_deprecation_warning()
self.data = scipy.sparse.hstack((self.data, other.data.to_numpy()), format=sparse_format)
else: else:
self.data = None self.data = None
elif isinstance(self.data, pd_DataFrame): elif isinstance(self.data, pd_DataFrame):
...@@ -3506,21 +3474,6 @@ class Dataset: ...@@ -3506,21 +3474,6 @@ class Dataset:
self.data = concat((self.data, pd_DataFrame(other.data.toarray())), axis=1, ignore_index=True) self.data = concat((self.data, pd_DataFrame(other.data.toarray())), axis=1, ignore_index=True)
elif isinstance(other.data, pd_DataFrame): elif isinstance(other.data, pd_DataFrame):
self.data = concat((self.data, other.data), axis=1, ignore_index=True) self.data = concat((self.data, other.data), axis=1, ignore_index=True)
elif isinstance(other.data, dt_DataTable):
_emit_datatable_deprecation_warning()
self.data = concat((self.data, pd_DataFrame(other.data.to_numpy())), axis=1, ignore_index=True)
else:
self.data = None
elif isinstance(self.data, dt_DataTable):
_emit_datatable_deprecation_warning()
if isinstance(other.data, np.ndarray):
self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data)))
elif isinstance(other.data, scipy.sparse.spmatrix):
self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.toarray())))
elif isinstance(other.data, pd_DataFrame):
self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.values)))
elif isinstance(other.data, dt_DataTable):
self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.to_numpy())))
else: else:
self.data = None self.data = None
else: else:
...@@ -4717,7 +4670,7 @@ class Booster: ...@@ -4717,7 +4670,7 @@ class Booster:
Parameters Parameters
---------- ----------
data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame (deprecated) or scipy.sparse data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table or scipy.sparse
Data source for prediction. Data source for prediction.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
start_iteration : int, optional (default=0) start_iteration : int, optional (default=0)
...@@ -4798,7 +4751,7 @@ class Booster: ...@@ -4798,7 +4751,7 @@ class Booster:
Parameters Parameters
---------- ----------
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, Sequence, list of Sequence or list of numpy array data : str, pathlib.Path, numpy array, pandas DataFrame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source for refit. Data source for refit.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray
......
...@@ -202,25 +202,6 @@ try: ...@@ -202,25 +202,6 @@ try:
except ImportError: except ImportError:
GRAPHVIZ_INSTALLED = False GRAPHVIZ_INSTALLED = False
"""datatable"""
try:
import datatable
if hasattr(datatable, "Frame"):
dt_DataTable = datatable.Frame
else:
dt_DataTable = datatable.DataTable
DATATABLE_INSTALLED = True
except ImportError:
DATATABLE_INSTALLED = False
class dt_DataTable: # type: ignore
"""Dummy class for datatable.DataTable."""
def __init__(self, *args: Any, **kwargs: Any):
pass
"""dask""" """dask"""
try: try:
from dask import delayed from dask import delayed
......
...@@ -41,7 +41,6 @@ from .compat import ( ...@@ -41,7 +41,6 @@ from .compat import (
_LGBMRegressorBase, _LGBMRegressorBase,
_LGBMValidateData, _LGBMValidateData,
_sklearn_version, _sklearn_version,
dt_DataTable,
pd_DataFrame, pd_DataFrame,
) )
from .engine import train from .engine import train
...@@ -58,7 +57,6 @@ __all__ = [ ...@@ -58,7 +57,6 @@ __all__ = [
] ]
_LGBM_ScikitMatrixLike = Union[ _LGBM_ScikitMatrixLike = Union[
dt_DataTable,
List[Union[List[float], List[int]]], List[Union[List[float], List[int]]],
np.ndarray, np.ndarray,
pd_DataFrame, pd_DataFrame,
...@@ -945,7 +943,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -945,7 +943,7 @@ class LGBMModel(_LGBMModelBase):
params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"] params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
params["metric"] = [metric for metric in params["metric"] if metric is not None] params["metric"] = [metric for metric in params["metric"] if metric is not None]
if not isinstance(X, (pd_DataFrame, dt_DataTable)): if not isinstance(X, pd_DataFrame):
_X, _y = _LGBMValidateData( _X, _y = _LGBMValidateData(
self, self,
X, X,
...@@ -1077,7 +1075,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -1077,7 +1075,7 @@ class LGBMModel(_LGBMModelBase):
fit.__doc__ = ( fit.__doc__ = (
_lgbmmodel_doc_fit.format( _lgbmmodel_doc_fit.format(
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]", X_shape="numpy array, pandas DataFrame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]", y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]",
sample_weight_shape="numpy array, pandas Series, list of int or float of shape = [n_samples] or None, optional (default=None)", sample_weight_shape="numpy array, pandas Series, list of int or float of shape = [n_samples] or None, optional (default=None)",
init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)", init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)",
...@@ -1104,7 +1102,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -1104,7 +1102,7 @@ class LGBMModel(_LGBMModelBase):
"""Docstring is set after definition, using a template.""" """Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.") raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
if not isinstance(X, (pd_DataFrame, dt_DataTable)): if not isinstance(X, pd_DataFrame):
X = _LGBMValidateData( X = _LGBMValidateData(
self, self,
X, X,
...@@ -1154,7 +1152,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -1154,7 +1152,7 @@ class LGBMModel(_LGBMModelBase):
predict.__doc__ = _lgbmmodel_doc_predict.format( predict.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted value for each sample.", description="Return the predicted value for each sample.",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]", X_shape="numpy array, pandas DataFrame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
output_name="predicted_result", output_name="predicted_result",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]", predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
...@@ -1648,7 +1646,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1648,7 +1646,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
predict_proba.__doc__ = _lgbmmodel_doc_predict.format( predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted probability for each class for each sample.", description="Return the predicted probability for each class for each sample.",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame (deprecated), scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]", X_shape="numpy array, pandas DataFrame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
output_name="predicted_probability", output_name="predicted_probability",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]", predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
......
...@@ -24,10 +24,8 @@ from sklearn.utils.validation import check_is_fitted ...@@ -24,10 +24,8 @@ from sklearn.utils.validation import check_is_fitted
import lightgbm as lgb import lightgbm as lgb
from lightgbm.compat import ( from lightgbm.compat import (
DASK_INSTALLED, DASK_INSTALLED,
DATATABLE_INSTALLED,
PANDAS_INSTALLED, PANDAS_INSTALLED,
_sklearn_version, _sklearn_version,
dt_DataTable,
pd_DataFrame, pd_DataFrame,
pd_Series, pd_Series,
) )
...@@ -1883,14 +1881,12 @@ def test_predict_rejects_inputs_with_incorrect_number_of_features(predict_disabl ...@@ -1883,14 +1881,12 @@ def test_predict_rejects_inputs_with_incorrect_number_of_features(predict_disabl
assert preds.shape[0] == y.shape[0] assert preds.shape[0] == y.shape[0]
@pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"]) @pytest.mark.parametrize("X_type", ["list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
@pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_Series", "pd_DataFrame"]) @pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_Series", "pd_DataFrame"])
@pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "regression"]) @pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "regression"])
def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types(X_type, y_type, task, rng): def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types(X_type, y_type, task, rng):
if any(t.startswith("pd_") for t in [X_type, y_type]) and not PANDAS_INSTALLED: if any(t.startswith("pd_") for t in [X_type, y_type]) and not PANDAS_INSTALLED:
pytest.skip("pandas is not installed") pytest.skip("pandas is not installed")
if any(t.startswith("dt_") for t in [X_type, y_type]) and not DATATABLE_INSTALLED:
pytest.skip("datatable is not installed")
X, y, g = _create_data(task, n_samples=2_000) X, y, g = _create_data(task, n_samples=2_000)
weights = np.abs(rng.standard_normal(size=(y.shape[0],))) weights = np.abs(rng.standard_normal(size=(y.shape[0],)))
...@@ -1902,9 +1898,7 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data ...@@ -1902,9 +1898,7 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
raise ValueError(f"Unrecognized task '{task}'") raise ValueError(f"Unrecognized task '{task}'")
X_valid = X * 2 X_valid = X * 2
if X_type == "dt_DataTable": if X_type == "list2d":
X = dt_DataTable(X)
elif X_type == "list2d":
X = X.tolist() X = X.tolist()
elif X_type == "scipy_csc": elif X_type == "scipy_csc":
X = scipy.sparse.csc_matrix(X) X = scipy.sparse.csc_matrix(X)
...@@ -1960,22 +1954,18 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data ...@@ -1960,22 +1954,18 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
raise ValueError(f"Unrecognized task: '{task}'") raise ValueError(f"Unrecognized task: '{task}'")
@pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"]) @pytest.mark.parametrize("X_type", ["list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
@pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_DataFrame", "pd_Series"]) @pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_DataFrame", "pd_Series"])
@pytest.mark.parametrize("g_type", ["list1d_float", "list1d_int", "numpy", "pd_Series"]) @pytest.mark.parametrize("g_type", ["list1d_float", "list1d_int", "numpy", "pd_Series"])
def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type, g_type, rng): def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type, g_type, rng):
if any(t.startswith("pd_") for t in [X_type, y_type, g_type]) and not PANDAS_INSTALLED: if any(t.startswith("pd_") for t in [X_type, y_type, g_type]) and not PANDAS_INSTALLED:
pytest.skip("pandas is not installed") pytest.skip("pandas is not installed")
if any(t.startswith("dt_") for t in [X_type, y_type, g_type]) and not DATATABLE_INSTALLED:
pytest.skip("datatable is not installed")
X, y, g = _create_data(task="ranking", n_samples=1_000) X, y, g = _create_data(task="ranking", n_samples=1_000)
weights = np.abs(rng.standard_normal(size=(y.shape[0],))) weights = np.abs(rng.standard_normal(size=(y.shape[0],)))
init_score = np.full_like(y, np.mean(y)) init_score = np.full_like(y, np.mean(y))
X_valid = X * 2 X_valid = X * 2
if X_type == "dt_DataTable": if X_type == "list2d":
X = dt_DataTable(X)
elif X_type == "list2d":
X = X.tolist() X = X.tolist()
elif X_type == "scipy_csc": elif X_type == "scipy_csc":
X = scipy.sparse.csc_matrix(X) X = scipy.sparse.csc_matrix(X)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment