Unverified Commit 2c9d3320 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] convert datatable to numpy directly (#1970)

* convert datatable to numpy directly

* fix according to comments

* updated more docstrings

* simplified isinstance check

* Update compat.py
parent 107b50b1
...@@ -36,7 +36,7 @@ The LightGBM Python module can load data from: ...@@ -36,7 +36,7 @@ The LightGBM Python module can load data from:
- libsvm/tsv/csv/txt format file - libsvm/tsv/csv/txt format file
- NumPy 2D array(s), pandas DataFrame, SciPy sparse matrix - NumPy 2D array(s), pandas DataFrame, H2O DataTable, SciPy sparse matrix
- LightGBM binary file - LightGBM binary file
......
...@@ -13,7 +13,7 @@ from tempfile import NamedTemporaryFile ...@@ -13,7 +13,7 @@ from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import (DataFrame, Series, from .compat import (DataFrame, Series, DataTable,
decode_string, string_type, decode_string, string_type,
integer_types, numeric_types, integer_types, numeric_types,
json, json_default_with_numpy, json, json_default_with_numpy,
...@@ -409,7 +409,7 @@ class _InnerPredictor(object): ...@@ -409,7 +409,7 @@ class _InnerPredictor(object):
Parameters Parameters
---------- ----------
data : string, numpy array, pandas DataFrame or scipy.sparse data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction. Data source for prediction.
When data type is string, it represents the path of txt file. When data type is string, it represents the path of txt file.
num_iteration : int, optional (default=-1) num_iteration : int, optional (default=-1)
...@@ -471,6 +471,8 @@ class _InnerPredictor(object): ...@@ -471,6 +471,8 @@ class _InnerPredictor(object):
except BaseException: except BaseException:
raise ValueError('Cannot convert data list to numpy array.') raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type) preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
elif isinstance(data, DataTable):
preds, nrow = self.__pred_for_np2d(data.to_numpy(), num_iteration, predict_type)
else: else:
try: try:
warnings.warn('Converting data to scipy sparse matrix.') warnings.warn('Converting data to scipy sparse matrix.')
...@@ -650,7 +652,7 @@ class Dataset(object): ...@@ -650,7 +652,7 @@ class Dataset(object):
Parameters Parameters
---------- ----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset. Data source of Dataset.
If string, it represents the path to txt file. If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
...@@ -789,6 +791,8 @@ class Dataset(object): ...@@ -789,6 +791,8 @@ class Dataset(object):
self.__init_from_np2d(data, params_str, ref_dataset) self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data): elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data):
self.__init_from_list_np2d(data, params_str, ref_dataset) self.__init_from_list_np2d(data, params_str, ref_dataset)
elif isinstance(data, DataTable):
self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
...@@ -1005,7 +1009,7 @@ class Dataset(object): ...@@ -1005,7 +1009,7 @@ class Dataset(object):
Parameters Parameters
---------- ----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset. Data source of Dataset.
If string, it represents the path to txt file. If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
...@@ -1395,7 +1399,7 @@ class Dataset(object): ...@@ -1395,7 +1399,7 @@ class Dataset(object):
Returns Returns
------- -------
data : string, numpy array, pandas DataFrame, scipy.sparse, list of numpy arrays or None data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse, list of numpy arrays or None
Raw data used in the Dataset construction. Raw data used in the Dataset construction.
""" """
if self.handle is None: if self.handle is None:
...@@ -1405,6 +1409,8 @@ class Dataset(object): ...@@ -1405,6 +1409,8 @@ class Dataset(object):
self.data = self.data[self.used_indices, :] self.data = self.data[self.used_indices, :]
elif isinstance(self.data, DataFrame): elif isinstance(self.data, DataFrame):
self.data = self.data.iloc[self.used_indices].copy() self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, DataTable):
self.data = self.data[self.used_indices, :]
else: else:
warnings.warn("Cannot subset {} type of raw data.\n" warnings.warn("Cannot subset {} type of raw data.\n"
"Returning original raw data".format(type(self.data).__name__)) "Returning original raw data".format(type(self.data).__name__))
...@@ -2156,7 +2162,7 @@ class Booster(object): ...@@ -2156,7 +2162,7 @@ class Booster(object):
Parameters Parameters
---------- ----------
data : string, numpy array, pandas DataFrame or scipy.sparse data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction. Data source for prediction.
If string, it represents the path to txt file. If string, it represents the path to txt file.
num_iteration : int or None, optional (default=None) num_iteration : int or None, optional (default=None)
...@@ -2201,7 +2207,7 @@ class Booster(object): ...@@ -2201,7 +2207,7 @@ class Booster(object):
Parameters Parameters
---------- ----------
data : string, numpy array, pandas DataFrame or scipy.sparse data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for refit. Data source for refit.
If string, it represents the path to txt file. If string, it represents the path to txt file.
label : list, numpy 1-D array or pandas Series / one-column DataFrame label : list, numpy 1-D array or pandas Series / one-column DataFrame
......
...@@ -90,6 +90,19 @@ try: ...@@ -90,6 +90,19 @@ try:
except ImportError: except ImportError:
GRAPHVIZ_INSTALLED = False GRAPHVIZ_INSTALLED = False
"""datatable"""
try:
from datatable import DataTable
DATATABLE_INSTALLED = True
except ImportError:
DATATABLE_INSTALLED = False
class DataTable(object):
"""Dummy class for DataTable."""
pass
"""sklearn""" """sklearn"""
try: try:
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
......
...@@ -11,7 +11,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, ...@@ -11,7 +11,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength, _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight, _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame) argc_, range_, string_type, DataFrame, DataTable)
from .engine import train from .engine import train
...@@ -479,7 +479,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -479,7 +479,7 @@ class LGBMModel(_LGBMModelBase):
eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric
params['metric'] = set(original_metric + eval_metric) params['metric'] = set(original_metric + eval_metric)
if not isinstance(X, DataFrame): if not isinstance(X, (DataFrame, DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(_X, _y, sample_weight) _LGBMCheckConsistentLength(_X, _y, sample_weight)
else: else:
...@@ -595,7 +595,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -595,7 +595,7 @@ class LGBMModel(_LGBMModelBase):
""" """
if self._n_features is None: if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.") raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
if not isinstance(X, DataFrame): if not isinstance(X, (DataFrame, DataTable)):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1] n_features = X.shape[1]
if self._n_features != n_features: if self._n_features != n_features:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment