Commit 2c9d3320 (unverified), authored by Guolin Ke, committed by GitHub
Browse files

[python] convert datatable to numpy directly (#1970)

* convert datatable to numpy directly

* fix according to comments

* updated more docstrings

* simplified isinstance check

* Update compat.py
parent 107b50b1
......@@ -36,7 +36,7 @@ The LightGBM Python module can load data from:
- libsvm/tsv/csv/txt format file
- NumPy 2D array(s), pandas DataFrame, SciPy sparse matrix
- NumPy 2D array(s), pandas DataFrame, H2O DataTable, SciPy sparse matrix
- LightGBM binary file
......
......@@ -13,7 +13,7 @@ from tempfile import NamedTemporaryFile
import numpy as np
import scipy.sparse
from .compat import (DataFrame, Series,
from .compat import (DataFrame, Series, DataTable,
decode_string, string_type,
integer_types, numeric_types,
json, json_default_with_numpy,
......@@ -409,7 +409,7 @@ class _InnerPredictor(object):
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction.
When data type is string, it represents the path of txt file.
num_iteration : int, optional (default=-1)
......@@ -471,6 +471,8 @@ class _InnerPredictor(object):
except BaseException:
raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
elif isinstance(data, DataTable):
preds, nrow = self.__pred_for_np2d(data.to_numpy(), num_iteration, predict_type)
else:
try:
warnings.warn('Converting data to scipy sparse matrix.')
......@@ -650,7 +652,7 @@ class Dataset(object):
Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
......@@ -789,6 +791,8 @@ class Dataset(object):
self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data):
self.__init_from_list_np2d(data, params_str, ref_dataset)
elif isinstance(data, DataTable):
self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
else:
try:
csr = scipy.sparse.csr_matrix(data)
......@@ -1005,7 +1009,7 @@ class Dataset(object):
Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
......@@ -1395,7 +1399,7 @@ class Dataset(object):
Returns
-------
data : string, numpy array, pandas DataFrame, scipy.sparse, list of numpy arrays or None
data : string, numpy array, pandas DataFrame, H2O DataTable, scipy.sparse, list of numpy arrays or None
Raw data used in the Dataset construction.
"""
if self.handle is None:
......@@ -1405,6 +1409,8 @@ class Dataset(object):
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, DataFrame):
self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, DataTable):
self.data = self.data[self.used_indices, :]
else:
warnings.warn("Cannot subset {} type of raw data.\n"
"Returning original raw data".format(type(self.data).__name__))
......@@ -2156,7 +2162,7 @@ class Booster(object):
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for prediction.
If string, it represents the path to txt file.
num_iteration : int or None, optional (default=None)
......@@ -2201,7 +2207,7 @@ class Booster(object):
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
data : string, numpy array, pandas DataFrame, H2O DataTable or scipy.sparse
Data source for refit.
If string, it represents the path to txt file.
label : list, numpy 1-D array or pandas Series / one-column DataFrame
......
......@@ -90,6 +90,19 @@ try:
except ImportError:
GRAPHVIZ_INSTALLED = False
"""datatable"""
try:
    # Optional dependency: H2O datatable.
    # Import the module rather than a single name, because the table class
    # was renamed across datatable releases: newer versions expose it as
    # ``Frame``, older ones as ``DataTable``. Importing ``DataTable``
    # directly raises ImportError on newer releases, which would silently
    # disable datatable support even though the package is installed.
    import datatable
    if hasattr(datatable, "Frame"):
        DataTable = datatable.Frame
    else:
        DataTable = datatable.DataTable
    DATATABLE_INSTALLED = True
except (ImportError, AttributeError):
    # datatable is missing (or is an unrelated module exposing neither
    # class): fall back to a placeholder so that
    # ``isinstance(x, DataTable)`` checks elsewhere stay valid and are
    # simply never true.
    DATATABLE_INSTALLED = False

    class DataTable(object):
        """Dummy class for DataTable."""

        pass
"""sklearn"""
try:
from sklearn.base import BaseEstimator
......
......@@ -11,7 +11,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame)
argc_, range_, string_type, DataFrame, DataTable)
from .engine import train
......@@ -479,7 +479,7 @@ class LGBMModel(_LGBMModelBase):
eval_metric = [eval_metric] if isinstance(eval_metric, (string_type, type(None))) else eval_metric
params['metric'] = set(original_metric + eval_metric)
if not isinstance(X, DataFrame):
if not isinstance(X, (DataFrame, DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(_X, _y, sample_weight)
else:
......@@ -595,7 +595,7 @@ class LGBMModel(_LGBMModelBase):
"""
if self._n_features is None:
raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
if not isinstance(X, DataFrame):
if not isinstance(X, (DataFrame, DataTable)):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1]
if self._n_features != n_features:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment