Commit c67d2890 authored by Guolin Ke
Browse files

move pandas support into basic.py

parent b12f9968
...@@ -358,6 +358,39 @@ class Predictor(object): ...@@ -358,6 +358,39 @@ class Predictor(object):
raise ValueError("incorrect number for predict result") raise ValueError("incorrect number for predict result")
return preds, nrow return preds, nrow
# Optional pandas support: when pandas cannot be imported, an empty stand-in
# class named DataFrame is defined so that isinstance() checks against it
# still work (and are never true for real inputs).
try:
    from pandas import DataFrame
except ImportError:
    class DataFrame(object):
        pass

# Names of the pandas column dtypes accepted as numeric input, each mapped
# to a short type tag.
PANDAS_DTYPE_MAPPER = {
    'int8': 'int',
    'int16': 'int',
    'int32': 'int',
    'int64': 'int',
    'uint8': 'int',
    'uint16': 'int',
    'uint32': 'int',
    'uint64': 'int',
    'float16': 'float',
    'float32': 'float',
    'float64': 'float',
    'bool': 'i',
}
def _data_from_pandas(data):
if isinstance(data, DataFrame):
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [data.columns[i] for i, dtype in
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
data = data.values.astype('float')
return data
def _label_from_pandas(label):
if isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
label = label.values.astype('float')
return label
class Dataset(object): class Dataset(object):
"""Dataset used in LightGBM. """Dataset used in LightGBM.
...@@ -398,6 +431,8 @@ class Dataset(object): ...@@ -398,6 +431,8 @@ class Dataset(object):
if data is None: if data is None:
self.handle = None self.handle = None
return return
data = _data_from_pandas(data)
label = _label_from_pandas(label)
self.data_has_header = False self.data_has_header = False
"""process for args""" """process for args"""
params = {} if params is None else params params = {} if params is None else params
......
...@@ -6,40 +6,6 @@ import numpy as np ...@@ -6,40 +6,6 @@ import numpy as np
from .basic import LightGBMError, Predictor, Dataset, Booster, is_str from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
from . import callback from . import callback
# Optional pandas support: when pandas cannot be imported, an empty stand-in
# class named DataFrame is defined so that isinstance() checks against it
# still work (and are never true for real inputs).
try:
    from pandas import DataFrame
except ImportError:
    class DataFrame(object):
        pass

# Names of the pandas column dtypes accepted as numeric input, each mapped
# to a short type tag.
PANDAS_DTYPE_MAPPER = {
    'int8': 'int',
    'int16': 'int',
    'int32': 'int',
    'int64': 'int',
    'uint8': 'int',
    'uint16': 'int',
    'uint32': 'int',
    'uint64': 'int',
    'float16': 'float',
    'float32': 'float',
    'float64': 'float',
    'bool': 'i',
}
def _data_from_pandas(data):
if isinstance(data, DataFrame):
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [data.columns[i] for i, dtype in
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
data = data.values.astype('float')
return data
def _label_from_pandas(label):
if isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
label = label.values.astype('float')
return label
def _construct_dataset(X_y, reference=None, def _construct_dataset(X_y, reference=None,
params=None, other_fields=None, predictor=None): params=None, other_fields=None, predictor=None):
if 'max_bin' in params: if 'max_bin' in params:
...@@ -61,8 +27,8 @@ def _construct_dataset(X_y, reference=None, ...@@ -61,8 +27,8 @@ def _construct_dataset(X_y, reference=None,
else: else:
if len(X_y) != 2: if len(X_y) != 2:
raise TypeError("should pass (data, label) pair") raise TypeError("should pass (data, label) pair")
data = _data_from_pandas(X_y[0]) data = X_y[0]
label = _label_from_pandas(X_y[1]) label = X_y[1]
if reference is None: if reference is None:
ret = Dataset(data, label=label, max_bin=max_bin, ret = Dataset(data, label=label, max_bin=max_bin,
weight=weight, group=group, predictor=predictor, params=params) weight=weight, group=group, predictor=predictor, params=params)
......
...@@ -54,7 +54,7 @@ def test_regression(): ...@@ -54,7 +54,7 @@ def test_regression():
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2')
preds = lgb_model.predict(x_test) preds = lgb_model.predict(x_test)
assert mean_squared_error(preds, y_test) < 30 assert mean_squared_error(preds, y_test) < 40
def test_regression_with_custom_objective(): def test_regression_with_custom_objective():
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error
...@@ -71,7 +71,7 @@ def test_regression_with_custom_objective(): ...@@ -71,7 +71,7 @@ def test_regression_with_custom_objective():
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2')
preds = lgb_model.predict(x_test) preds = lgb_model.predict(x_test)
assert mean_squared_error(preds, y_test) < 30 assert mean_squared_error(preds, y_test) < 40
def test_binary_classification_with_custom_objective(): def test_binary_classification_with_custom_objective():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment