Commit b12f9968 authored by Guolin Ke
Browse files

add pandas support

parent a1567983
......@@ -6,7 +6,41 @@ import numpy as np
from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
from . import callback
def _construct_dataset(data, reference=None,
# Optional pandas support: pandas is only needed when callers hand in DataFrames.
try:
    from pandas import DataFrame
except ImportError:
    class DataFrame(object):
        """Stand-in used when pandas is not installed.

        Nothing ever instantiates it, so ``isinstance(x, DataFrame)`` checks
        in this module simply evaluate to False.
        """


# Names of the pandas column dtypes accepted for conversion, each mapped to a
# short type tag. Built group by group, in the same key order as a literal.
PANDAS_DTYPE_MAPPER = dict.fromkeys(
    ('int8', 'int16', 'int32', 'int64',
     'uint8', 'uint16', 'uint32', 'uint64'),
    'int')
PANDAS_DTYPE_MAPPER.update(
    dict.fromkeys(('float16', 'float32', 'float64'), 'float'))
PANDAS_DTYPE_MAPPER['bool'] = 'i'
def _data_from_pandas(data):
if isinstance(data, DataFrame):
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [data.columns[i] for i, dtype in
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
data = data.values.astype('float')
return data
def _label_from_pandas(label):
if isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
label = label.values.astype('float')
return label
def _construct_dataset(X_y, reference=None,
params=None, other_fields=None, predictor=None):
if 'max_bin' in params:
max_bin = int(params['max_bin'])
......@@ -21,18 +55,20 @@ def _construct_dataset(data, reference=None,
weight = None if 'weight' not in other_fields else other_fields['weight']
group = None if 'group' not in other_fields else other_fields['group']
init_score = None if 'init_score' not in other_fields else other_fields['init_score']
if is_str(X_y):
data = X_y
label = None
else:
if len(X_y) != 2:
raise TypeError("should pass (data, label) pair")
data = _data_from_pandas(X_y[0])
label = _label_from_pandas(X_y[1])
if reference is None:
if is_str(data):
ret = Dataset(data, label=None, max_bin=max_bin,
weight=weight, group=group, predictor=predictor, params=params)
else:
ret = Dataset(data[0], data[1], max_bin=max_bin,
weight=weight, group=group, predictor=predictor, params=params)
ret = Dataset(data, label=label, max_bin=max_bin,
weight=weight, group=group, predictor=predictor, params=params)
else:
if is_str(data):
ret = reference.create_valid(data, label=None, weight=weight, group=group, params=params)
else:
ret = reference.create_valid(data[0], data[1], weight, group, params=params)
ret = reference.create_valid(data, label=label, weight=weight, group=group, params=params)
if init_score is not None:
ret.set_init_score(init_score)
return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment