"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "ed0a7f2c771569cf0b38a317b25bd12e415395cc"
Unverified Commit 2e9848c6 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] simplify Dataset processing of label (#5456)

parent 1444a748
......@@ -23,6 +23,12 @@ from .libpath import find_lib_path
_DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_LabelType = Union[
list,
np.ndarray,
pd_Series,
pd_DataFrame
]
ZERO_THRESHOLD = 1e-35
......@@ -605,15 +611,6 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
return data, feature_name, categorical_feature, pandas_categorical
def _label_from_pandas(label):
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
label = np.ravel(label.values.astype(np.float32, copy=False))
return label
def _dump_pandas_categorical(pandas_categorical, file_name=None):
categorical_json = json.dumps(pandas_categorical, default=json_default_with_numpy)
pandas_str = f'\npandas_categorical:{categorical_json}\n'
......@@ -1200,7 +1197,7 @@ class Dataset:
def __init__(
self,
data,
label=None,
label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
......@@ -1505,7 +1502,7 @@ class Dataset:
def _lazy_init(
self,
data,
label=None,
label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
......@@ -1525,7 +1522,6 @@ class Dataset:
feature_name,
categorical_feature,
self.pandas_categorical)
label = _label_from_pandas(label)
# process for args
params = {} if params is None else params
......@@ -1936,7 +1932,7 @@ class Dataset:
def create_valid(
self,
data,
label=None,
label: Optional[_LGBM_LabelType] = None,
weight=None,
group=None,
init_score=None,
......@@ -2276,7 +2272,7 @@ class Dataset:
ctypes.c_int(len(feature_name))))
return self
def set_label(self, label) -> "Dataset":
def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
"""Set label of Dataset.
Parameters
......@@ -2291,8 +2287,14 @@ class Dataset:
"""
self.label = label
if self.handle is not None:
label = list_to_1d_numpy(_label_from_pandas(label), name='label')
self.set_field('label', label)
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
label_array = np.ravel(label.values.astype(np.float32, copy=False))
else:
label_array = list_to_1d_numpy(label, name='label')
self.set_field('label', label_array)
self.label = self.get_field('label') # original values can be modified at cpp side
return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment