Commit 2e9848c6 (unverified), authored by James Lamb, committed by GitHub
Browse files

[python-package] simplify Dataset processing of label (#5456)

parent 1444a748
...@@ -23,6 +23,12 @@ from .libpath import find_lib_path ...@@ -23,6 +23,12 @@ from .libpath import find_lib_path
_DatasetHandle = ctypes.c_void_p _DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool] _LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool] _LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
# Container types accepted for the ``label`` argument of Dataset methods
# (used below as ``Optional[_LGBM_LabelType]`` in their signatures).
_LGBM_LabelType = Union[
list,
np.ndarray,
pd_Series,
pd_DataFrame
]
ZERO_THRESHOLD = 1e-35 ZERO_THRESHOLD = 1e-35
...@@ -605,15 +611,6 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica ...@@ -605,15 +611,6 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
return data, feature_name, categorical_feature, pandas_categorical return data, feature_name, categorical_feature, pandas_categorical
def _label_from_pandas(label):
    """Flatten a single-column pandas DataFrame label into a 1-D float32 array.

    Parameters
    ----------
    label : object
        Candidate label. Anything that is not a pandas DataFrame is
        returned unchanged.

    Returns
    -------
    label : object
        A flattened ``np.float32`` array when a DataFrame was given,
        otherwise the input itself.

    Raises
    ------
    ValueError
        If the DataFrame has more than one column.
    """
    # Non-DataFrame labels pass straight through.
    if not isinstance(label, pd_DataFrame):
        return label

    n_columns = len(label.columns)
    if n_columns > 1:
        raise ValueError('DataFrame for label cannot have multiple columns')

    # dtype validation is delegated to the shared pandas helper.
    _check_for_bad_pandas_dtypes(label.dtypes)

    # copy=False avoids an extra copy when the data is already float32.
    as_float32 = label.values.astype(np.float32, copy=False)
    return np.ravel(as_float32)
def _dump_pandas_categorical(pandas_categorical, file_name=None): def _dump_pandas_categorical(pandas_categorical, file_name=None):
categorical_json = json.dumps(pandas_categorical, default=json_default_with_numpy) categorical_json = json.dumps(pandas_categorical, default=json_default_with_numpy)
pandas_str = f'\npandas_categorical:{categorical_json}\n' pandas_str = f'\npandas_categorical:{categorical_json}\n'
...@@ -1200,7 +1197,7 @@ class Dataset: ...@@ -1200,7 +1197,7 @@ class Dataset:
def __init__( def __init__(
self, self,
data, data,
label=None, label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None, reference: Optional["Dataset"] = None,
weight=None, weight=None,
group=None, group=None,
...@@ -1505,7 +1502,7 @@ class Dataset: ...@@ -1505,7 +1502,7 @@ class Dataset:
def _lazy_init( def _lazy_init(
self, self,
data, data,
label=None, label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None, reference: Optional["Dataset"] = None,
weight=None, weight=None,
group=None, group=None,
...@@ -1525,7 +1522,6 @@ class Dataset: ...@@ -1525,7 +1522,6 @@ class Dataset:
feature_name, feature_name,
categorical_feature, categorical_feature,
self.pandas_categorical) self.pandas_categorical)
label = _label_from_pandas(label)
# process for args # process for args
params = {} if params is None else params params = {} if params is None else params
...@@ -1936,7 +1932,7 @@ class Dataset: ...@@ -1936,7 +1932,7 @@ class Dataset:
def create_valid( def create_valid(
self, self,
data, data,
label=None, label: Optional[_LGBM_LabelType] = None,
weight=None, weight=None,
group=None, group=None,
init_score=None, init_score=None,
...@@ -2276,7 +2272,7 @@ class Dataset: ...@@ -2276,7 +2272,7 @@ class Dataset:
ctypes.c_int(len(feature_name)))) ctypes.c_int(len(feature_name))))
return self return self
def set_label(self, label) -> "Dataset": def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
"""Set label of Dataset. """Set label of Dataset.
Parameters Parameters
...@@ -2291,8 +2287,14 @@ class Dataset: ...@@ -2291,8 +2287,14 @@ class Dataset:
""" """
self.label = label self.label = label
if self.handle is not None: if self.handle is not None:
label = list_to_1d_numpy(_label_from_pandas(label), name='label') if isinstance(label, pd_DataFrame):
self.set_field('label', label) if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
label_array = np.ravel(label.values.astype(np.float32, copy=False))
else:
label_array = list_to_1d_numpy(label, name='label')
self.set_field('label', label_array)
self.label = self.get_field('label') # original values can be modified at cpp side self.label = self.get_field('label') # original values can be modified at cpp side
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment