Commit f893fbf6 authored by wxchan, committed by Guolin Ke

simplify Dataset class (#163)

* simplify Dataset class

* simplify check output; fix deprecated warning
parent 21ee5947
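
For orientation, the net effect of the change below is that the former _InnerDataset wrapper is folded into Dataset itself: a Dataset now only records its constructor arguments and builds the underlying C handle lazily in construct(), dropping the raw data afterwards unless free_raw_data=False. A minimal sketch of the resulting usage, assuming the public lightgbm package API of this era (data and parameter values are illustrative):

    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(500, 10)
    y = np.random.randint(0, 2, 500)

    # Creating the Dataset is now cheap: the arguments are only stored.
    train_data = lgb.Dataset(X, label=y, max_bin=255, free_raw_data=False)

    # The C dataset handle is built on the first construct() call, which
    # lgb.train()/Booster trigger internally via train_set.construct().handle.
    bst = lgb.train({'objective': 'binary', 'verbose': -1}, train_data, num_boost_round=10)

    preds = bst.predict(X)   # prediction always takes raw data, never a Dataset
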
@@ -305,7 +305,7 @@ class _InnerPredictor(object):
         -------
             Prediction result
         """
-        if isinstance(data, (_InnerDataset, Dataset)):
+        if isinstance(data, Dataset):
             raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
         predict_type = C_API_PREDICT_NORMAL
         if raw_score:
@@ -493,46 +493,64 @@ def _label_from_pandas(label):
     return label
-class _InnerDataset(object):
-    """_InnerDataset used in LightGBM.
-    _InnerDataset is a internal data structure that used by LightGBM.
-    This class is not exposed. Please use Dataset instead
-    """
+class Dataset(object):
+    """Dataset in LightGBM."""
     def __init__(self, data, label=None, max_bin=255, reference=None,
-                 weight=None, group=None, predictor=None,
-                 silent=False, feature_name=None,
-                 categorical_feature=None, params=None):
+                 weight=None, group=None, silent=False,
+                 feature_name=None, categorical_feature=None, params=None,
+                 free_raw_data=True):
         """
-        _InnerDataset used in LightGBM.
         Parameters
         ----------
         data : string/numpy array/scipy.sparse
-            Data source of _InnerDataset.
+            Data source of Dataset.
             When data type is string, it represents the path of txt file
         label : list or numpy 1-D array, optional
             Label of the data
         max_bin : int, required
             Max number of discrete bin for features
-        reference : Other _InnerDataset, optional
+        reference : Other Dataset, optional
             If this dataset validation, need to use training data as reference
         weight : list or numpy 1-D array , optional
             Weight for each instance.
         group : list or numpy 1-D array , optional
             Group/query size for dataset
-        predictor : _InnerPredictor
-            Used for continuned train
         silent : boolean, optional
             Whether print messages during construction
         feature_name : list of str
             Feature names
         categorical_feature : list of str or int
-            Categorical features, type int represents index, \
+            Categorical features,
+            type int represents index,
             type str represents feature names (need to specify feature_name as well)
         params: dict, optional
             Other parameters
+        free_raw_data: Bool
+            True if need to free raw data after construct inner dataset
         """
+        self.handle = None
+        self.data = data
+        self.label = label
+        self.max_bin = max_bin
+        self.reference = reference
+        self.weight = weight
+        self.group = group
+        self.silent = silent
+        self.feature_name = feature_name
+        self.categorical_feature = categorical_feature
+        self.params = params
+        self.free_raw_data = free_raw_data
+        self._is_constructed = False
+        self.used_indices = None
+        self._predictor = None
+    def __del__(self):
+        _safe_call(_LIB.LGBM_DatasetFree(self.handle))
+    def _lazy_init(self, data, label=None, max_bin=255, reference=None,
+                   weight=None, group=None, predictor=None,
+                   silent=False, feature_name=None,
+                   categorical_feature=None, params=None):
         if data is None:
             self.handle = None
             return
@@ -568,7 +586,7 @@ class _InnerDataset(object):
         params_str = param_dict_to_str(params)
         """process for reference dataset"""
         ref_dataset = None
-        if isinstance(reference, _InnerDataset):
+        if isinstance(reference, Dataset):
             ref_dataset = reference.handle
         elif reference is not None:
             raise TypeError('Reference dataset should be None or dataset instance')
@@ -595,7 +613,7 @@ class _InnerDataset(object):
                 csr = scipy.sparse.csr_matrix(data)
                 self.__init_from_csr(csr, params_str, ref_dataset)
             except:
-                raise TypeError('Cannot initialize _InnerDataset from {}'.format(type(data).__name__))
+                raise TypeError('Cannot initialize Dataset from {}'.format(type(data).__name__))
         if label is not None:
             self.set_label(label)
         if self.get_label() is None:
@@ -625,65 +643,6 @@ class _InnerDataset(object):
         # set feature names
         self.set_feature_name(feature_name)
-    def create_valid(self, data, label=None, weight=None, group=None,
-                     silent=False, params=None):
-        """
-        Create validation data align with current dataset
-        Parameters
-        ----------
-        data : string/numpy array/scipy.sparse
-            Data source of _InnerDataset.
-            When data type is string, it represents the path of txt file
-        label : list or numpy 1-D array, optional
-            Label of the training data.
-        weight : list or numpy 1-D array , optional
-            Weight for each instance.
-        group : list or numpy 1-D array , optional
-            Group/query size for dataset
-        silent : boolean, optional
-            Whether print messages during construction
-        params: dict, optional
-            Other parameters
-        """
-        return _InnerDataset(data, label=label, max_bin=self.max_bin, reference=self,
-                             weight=weight, group=group, predictor=self.predictor,
-                             silent=silent, params=params)
-    def subset(self, used_indices, params=None):
-        """
-        Get subset of current dataset
-        """
-        used_indices = list_to_1d_numpy(used_indices, np.int32, name='used_indices')
-        ret = _InnerDataset(None)
-        ret.handle = ctypes.c_void_p()
-        params_str = param_dict_to_str(params)
-        _safe_call(_LIB.LGBM_DatasetGetSubset(
-            self.handle,
-            used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
-            used_indices.shape[0],
-            c_str(params_str),
-            ctypes.byref(ret.handle)))
-        ret.max_bin = self.max_bin
-        ret.predictor = self.predictor
-        if ret.get_label() is None:
-            raise ValueError("Label should not be None")
-        return ret
-    def set_feature_name(self, feature_name):
-        """
-        set feature names
-        """
-        if feature_name is None:
-            return
-        if len(feature_name) != self.num_feature():
-            raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature()))
-        c_feature_name = [c_str(name) for name in feature_name]
-        _safe_call(_LIB.LGBM_DatasetSetFeatureNames(
-            self.handle,
-            c_array(ctypes.c_char_p, c_feature_name),
-            len(feature_name)))
     def __init_from_np2d(self, mat, params_str, ref_dataset):
         """
         Initialize data from a 2-D numpy matrix.
@@ -757,44 +716,105 @@ class _InnerDataset(object):
             ref_dataset,
             ctypes.byref(self.handle)))
-    def __del__(self):
-        _safe_call(_LIB.LGBM_DatasetFree(self.handle))
-    def get_field(self, field_name):
-        """Get property from the _InnerDataset.
-        Parameters
-        ----------
-        field_name: str
-            The field name of the information
-        Returns
-        -------
-        info : array
-            A numpy array of information of the data
-        """
-        tmp_out_len = ctypes.c_int64()
-        out_type = ctypes.c_int32()
-        ret = ctypes.POINTER(ctypes.c_void_p)()
-        _safe_call(_LIB.LGBM_DatasetGetField(
-            self.handle,
-            c_str(field_name),
-            ctypes.byref(tmp_out_len),
-            ctypes.byref(ret),
-            ctypes.byref(out_type)))
-        if out_type.value != FIELD_TYPE_MAPPER[field_name]:
-            raise TypeError("Return type error for get_field")
-        if tmp_out_len.value == 0:
-            return None
-        if out_type.value == C_API_DTYPE_INT32:
-            return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value)
-        elif out_type.value == C_API_DTYPE_FLOAT32:
-            return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value)
-        else:
-            raise TypeError("Unknown type")
+    def construct(self):
+        """Lazy init"""
+        if not self._is_constructed:
+            self._is_constructed = True
+            if self.reference is not None:
+                if self.used_indices is None:
+                    """create valid"""
+                    self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference,
+                                    weight=self.weight, group=self.group, predictor=self._predictor,
+                                    silent=self.silent, params=self.params)
+                else:
+                    """construct subset"""
+                    used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices')
+                    handle, self.handle = self.handle, ctypes.c_void_p()
+                    params_str = param_dict_to_str(self.params)
+                    _safe_call(_LIB.LGBM_DatasetGetSubset(
+                        handle,
+                        used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+                        used_indices.shape[0],
+                        c_str(params_str),
+                        ctypes.byref(self.handle)))
+                    if self.get_label() is None:
+                        raise ValueError("Label should not be None.")
+            else:
+                """create train"""
+                self._lazy_init(self.data, label=self.label, max_bin=self.max_bin,
+                                weight=self.weight, group=self.group, predictor=self._predictor,
+                                silent=self.silent, feature_name=self.feature_name,
+                                categorical_feature=self.categorical_feature, params=self.params)
+            if self.free_raw_data:
+                self.data = None
+        return self
+    def create_valid(self, data, label=None, weight=None, group=None,
+                     silent=False, params=None):
+        """
+        Create validation data align with current dataset
+        Parameters
+        ----------
+        data : string/numpy array/scipy.sparse
+            Data source of Dataset.
+            When data type is string, it represents the path of txt file
+        label : list or numpy 1-D array, optional
+            Label of the training data.
+        weight : list or numpy 1-D array , optional
+            Weight for each instance.
+        group : list or numpy 1-D array , optional
+            Group/query size for dataset
+        silent : boolean, optional
+            Whether print messages during construction
+        params: dict, optional
+            Other parameters
+        """
+        ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self,
+                      weight=weight, group=group, silent=silent, params=params,
+                      free_raw_data=self.free_raw_data)
+        ret._set_predictor(self._predictor)
+        return ret
+    def subset(self, used_indices, params=None):
+        """
+        Get subset of current dataset
+        Parameters
+        ----------
+        used_indices : list of int
+            Used indices of this subset
+        params : dict
+            Other parameters
+        """
+        ret = Dataset(None, reference=self, feature_name=self.feature_name,
+                      categorical_feature=self.categorical_feature, params=params)
+        ret._predictor = self._predictor
+        ret.used_indices = used_indices
+        ret.handle = self.handle
+        return ret
+    def save_binary(self, filename):
+        """
+        Save Dataset to binary file
+        Parameters
+        ----------
+        filename : string
+            Name of the output file.
+        """
+        _safe_call(_LIB.LGBM_DatasetSaveBinary(
+            self.construct().handle,
+            c_str(filename)))
+    def _update_params(self, params):
+        if not self.params:
+            self.params = params
+        else:
+            self.params.update(params)
     def set_field(self, field_name, data):
-        """Set property into the _InnerDataset.
+        """Set property into the Dataset.
         Parameters
         ----------
@@ -832,241 +852,38 @@ class _InnerDataset(object):
             len(data),
             type_data))
-    def save_binary(self, filename):
-        """Save _InnerDataset to binary file
-        Parameters
-        ----------
-        filename : string
-            Name of the output file.
-        """
-        _safe_call(_LIB.LGBM_DatasetSaveBinary(
-            self.handle,
-            c_str(filename)))
-    def set_label(self, label):
-        """Set label of _InnerDataset
-        Parameters
-        ----------
-        label: numpy array or list or None
-            The label information to be set into _InnerDataset
-        """
-        label = list_to_1d_numpy(label, name='label')
-        self.set_field('label', label)
-    def set_weight(self, weight):
-        """ Set weight of each instance.
-        Parameters
-        ----------
-        weight : numpy array or list or None
-            Weight for each data point
-        """
-        if weight is not None:
-            weight = list_to_1d_numpy(weight, name='weight')
-        self.set_field('weight', weight)
-    def set_init_score(self, score):
-        """Set init score of booster to start from.
-        Parameters
-        ----------
-        score: numpy array or list or None
-            Init score for booster
-        """
-        if score is not None:
-            score = list_to_1d_numpy(score, name='init_score')
-        self.set_field('init_score', score)
-    def set_group(self, group):
-        """Set group size of _InnerDataset (used for ranking).
-        Parameters
-        ----------
-        group : numpy array or list or None
-            Group size of each group
-        """
-        if group is not None:
-            group = list_to_1d_numpy(group, np.int32, name='group')
-        self.set_field('group', group)
-    def get_label(self):
-        """Get the label of the _InnerDataset.
-        Returns
-        -------
-        label : array
-        """
-        return self.get_field('label')
-    def get_weight(self):
-        """Get the weight of the _InnerDataset.
-        Returns
-        -------
-        weight : array
-        """
-        return self.get_field('weight')
-    def get_init_score(self):
-        """Get the initial score of the _InnerDataset.
-        Returns
-        -------
-        init_score : array
-        """
-        return self.get_field('init_score')
-    def get_group(self):
-        """Get the initial score of the _InnerDataset.
-        Returns
-        -------
-        init_score : array
-        """
-        return self.get_field('group')
-    def num_data(self):
-        """Get the number of rows in the _InnerDataset.
-        Returns
-        -------
-        number of rows : int
-        """
-        ret = ctypes.c_int64()
-        _safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
-                                               ctypes.byref(ret)))
-        return ret.value
-    def num_feature(self):
-        """Get the number of columns (features) in the _InnerDataset.
-        Returns
-        -------
-        number of columns : int
-        """
-        ret = ctypes.c_int64()
-        _safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
-                                                  ctypes.byref(ret)))
-        return ret.value
-class Dataset(object):
-    """High level Dataset used in LightGBM.
-    """
-    def __init__(self, data, label=None, max_bin=255, reference=None,
-                 weight=None, group=None, silent=False,
-                 feature_name=None, categorical_feature=None, params=None,
-                 free_raw_data=True):
-        """
-        Parameters
-        ----------
-        data : string/numpy array/scipy.sparse
-            Data source of Dataset.
-            When data type is string, it represents the path of txt file
-        label : list or numpy 1-D array, optional
-            Label of the data
-        max_bin : int, required
-            Max number of discrete bin for features
-        reference : Other Dataset, optional
-            If this dataset validation, need to use training data as reference
-        weight : list or numpy 1-D array , optional
-            Weight for each instance.
-        group : list or numpy 1-D array , optional
-            Group/query size for dataset
-        silent : boolean, optional
-            Whether print messages during construction
-        feature_name : list of str
-            Feature names
-        categorical_feature : list of str or int
-            Categorical features,
-            type int represents index,
-            type str represents feature names (need to specify feature_name as well)
-        params: dict, optional
-            Other parameters
-        free_raw_data: Bool
-            True if need to free raw data after construct inner dataset
-        """
-        self.data = data
-        self.label = label
-        self.max_bin = max_bin
-        self.reference = reference
-        self.weight = weight
-        self.group = group
-        self.silent = silent
-        self.feature_name = feature_name
-        self.categorical_feature = categorical_feature
-        self.params = params
-        self.free_raw_data = free_raw_data
-        self.inner_dataset = None
-        self.used_indices = None
-        self._predictor = None
-    def create_valid(self, data, label=None, weight=None, group=None,
-                     silent=False, params=None):
-        """
-        Create validation data align with current dataset
-        Parameters
-        ----------
-        data : string/numpy array/scipy.sparse
-            Data source of Dataset.
-            When data type is string, it represents the path of txt file
-        label : list or numpy 1-D array, optional
-            Label of the training data.
-        weight : list or numpy 1-D array , optional
-            Weight for each instance.
-        group : list or numpy 1-D array , optional
-            Group/query size for dataset
-        silent : boolean, optional
-            Whether print messages during construction
-        params: dict, optional
-            Other parameters
-        """
-        ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self,
-                      weight=weight, group=group,
-                      silent=silent, params=params, free_raw_data=self.free_raw_data)
-        ret._set_predictor(self._predictor)
-        return ret
-    def _update_params(self, params):
-        if not self.params:
-            self.params = params
-        else:
-            self.params.update(params)
-    def construct(self):
-        """
-        Lazy init
-        """
-        if self.inner_dataset is None:
-            if self.reference is not None:
-                if self.used_indices is None:
-                    self.inner_dataset = self.reference._get_inner_dataset().create_valid(
-                        self.data, self.label,
-                        self.weight, self.group,
-                        self.silent, self.params)
-                else:
-                    """construct subset"""
-                    self.inner_dataset = self.reference._get_inner_dataset().subset(
-                        self.used_indices, self.params)
-            else:
-                self.inner_dataset = _InnerDataset(self.data, self.label, self.max_bin,
-                                                   None, self.weight, self.group, self._predictor,
-                                                   self.silent, self.feature_name,
-                                                   self.categorical_feature, self.params)
-            if self.free_raw_data:
-                self.data = None
-    def _get_inner_dataset(self):
-        """get inner dataset"""
-        self.construct()
-        return self.inner_dataset
-    def __is_constructed(self):
-        """check inner_dataset is constructed or not"""
-        return self.inner_dataset is not None
+    def get_field(self, field_name):
+        """Get property from the Dataset.
+        Parameters
+        ----------
+        field_name: str
+            The field name of the information
+        Returns
+        -------
+        info : array
+            A numpy array of information of the data
+        """
+        tmp_out_len = ctypes.c_int64()
+        out_type = ctypes.c_int32()
+        ret = ctypes.POINTER(ctypes.c_void_p)()
+        _safe_call(_LIB.LGBM_DatasetGetField(
+            self.handle,
+            c_str(field_name),
+            ctypes.byref(tmp_out_len),
+            ctypes.byref(ret),
+            ctypes.byref(out_type)))
+        if out_type.value != FIELD_TYPE_MAPPER[field_name]:
+            raise TypeError("Return type error for get_field")
+        if tmp_out_len.value == 0:
+            return None
+        if out_type.value == C_API_DTYPE_INT32:
+            return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value)
+        elif out_type.value == C_API_DTYPE_FLOAT32:
+            return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value)
+        else:
+            raise TypeError("Unknown type")
     def set_categorical_feature(self, categorical_feature):
         """
@@ -1082,7 +899,7 @@ class Dataset(object):
             return
         if self.data is not None:
             self.categorical_feature = categorical_feature
-            self.inner_dataset = None
+            self._is_constructed = False
         else:
             raise LightGBMError("Cannot set categorical feature after freed raw data,\
                                 Set free_raw_data=False when construct Dataset to avoid this.")
@@ -1096,7 +913,7 @@ class Dataset(object):
             return
         if self.data is not None:
             self._predictor = predictor
-            self.inner_dataset = None
+            self._is_constructed = False
         else:
             raise LightGBMError("Cannot set predictor after freed raw data,Set free_raw_data=False when construct Dataset to avoid this.")
@@ -1116,7 +933,7 @@ class Dataset(object):
             return
         if self.data is not None:
             self.reference = reference
-            self.inner_dataset = None
+            self._is_constructed = False
         else:
             raise LightGBMError("Cannot set reference after freed raw data,\
                                 Set free_raw_data=False when construct Dataset to avoid this.")
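
These setters work by invalidating the constructed state rather than mutating the existing handle: they flip _is_constructed back to False, so the next construct() rebuilds the dataset from the retained raw data, which is why they refuse to run once free_raw_data=True has dropped it. A small illustrative sketch (names and data are made up, assuming the public API after this change):

    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(300, 4)
    y = np.random.randint(0, 2, 300)

    d = lgb.Dataset(X, label=y, free_raw_data=False).construct()
    d.set_categorical_feature([0, 2])   # marks the dataset as not constructed again
    d.construct()                       # rebuilds the handle with the new setting

    d2 = lgb.Dataset(X, label=y).construct()   # free_raw_data defaults to True
    # d2.set_categorical_feature([0, 2])       # would raise LightGBMError: raw data already freed
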
@@ -1131,39 +948,14 @@ class Dataset(object):
             Feature names
         """
         self.feature_name = feature_name
-        if self.__is_constructed():
-            self.inner_dataset.set_feature_name(self.feature_name)
-    def subset(self, used_indices, params=None):
-        """
-        Get subset of current dataset
-        Parameters
-        ----------
-        used_indices : list of int
-            Used indices of this subset
-        params : dict
-            Other parameters
-        """
-        ret = Dataset(None)
-        ret.feature_name = self.feature_name
-        ret.categorical_feature = self.categorical_feature
-        ret.reference = self
-        ret._predictor = self._predictor
-        ret.used_indices = used_indices
-        ret.params = params
-        return ret
-    def save_binary(self, filename):
-        """
-        Save Dataset to binary file
-        Parameters
-        ----------
-        filename : string
-            Name of the output file.
-        """
-        self._get_inner_dataset().save_binary(filename)
+        if self._is_constructed and feature_name is not None:
+            if len(feature_name) != self.num_feature():
+                raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature()))
+            c_feature_name = [c_str(name) for name in feature_name]
+            _safe_call(_LIB.LGBM_DatasetSetFeatureNames(
+                self.handle,
+                c_array(ctypes.c_char_p, c_feature_name),
+                len(feature_name)))
     def set_label(self, label):
         """
@@ -1175,8 +967,9 @@ class Dataset(object):
             The label information to be set into Dataset
         """
         self.label = label
-        if self.__is_constructed():
-            self.inner_dataset.set_label(self.label)
+        if self._is_constructed:
+            label = list_to_1d_numpy(label, name='label')
+            self.set_field('label', label)
     def set_weight(self, weight):
         """
@@ -1188,8 +981,9 @@ class Dataset(object):
             Weight for each data point
         """
         self.weight = weight
-        if self.__is_constructed():
-            self.inner_dataset.set_weight(self.weight)
+        if self._is_constructed and weight is not None:
+            weight = list_to_1d_numpy(weight, name='weight')
+            self.set_field('weight', weight)
     def set_init_score(self, init_score):
         """
@@ -1201,8 +995,9 @@ class Dataset(object):
             Init score for booster
         """
         self.init_score = init_score
-        if self.__is_constructed():
-            self.inner_dataset.set_init_score(self.init_score)
+        if self._is_constructed and init_score is not None:
+            init_score = list_to_1d_numpy(init_score, name='init_score')
+            self.set_field('init_score', init_score)
     def set_group(self, group):
         """
@@ -1214,8 +1009,9 @@ class Dataset(object):
            Group size of each group
         """
         self.group = group
-        if self.__is_constructed():
-            self.inner_dataset.set_group(self.group)
+        if self._is_constructed and group is not None:
+            group = list_to_1d_numpy(group, np.int32, name='group')
+            self.set_field('group', group)
     def get_label(self):
         """
@@ -1225,8 +1021,8 @@ class Dataset(object):
         -------
         label : array
         """
-        if self.label is None and self.__is_constructed():
-            self.label = self.inner_dataset.get_label()
+        if self.label is None and self._is_constructed:
+            self.label = self.get_field('label')
         return self.label
     def get_weight(self):
@@ -1237,8 +1033,8 @@ class Dataset(object):
         -------
         weight : array
         """
-        if self.weight is None and self.__is_constructed():
-            self.weight = self.inner_dataset.get_weight()
+        if self.weight is None and self._is_constructed:
+            self.weight = self.get_field('weight')
         return self.weight
     def get_init_score(self):
@@ -1249,8 +1045,8 @@ class Dataset(object):
         -------
         init_score : array
         """
-        if self.init_score is None and self.__is_constructed():
-            self.init_score = self.inner_dataset.get_init_score()
+        if self.init_score is None and self._is_constructed:
+            self.init_score = self.get_field('init_score')
         return self.init_score
     def get_group(self):
@@ -1261,8 +1057,8 @@ class Dataset(object):
         -------
         init_score : array
         """
-        if self.group is None and self.__is_constructed():
-            self.group = self.inner_dataset.get_group()
+        if self.group is None and self._is_constructed:
+            self.group = self.get_field('group')
         if self.group is not None:
             # group data from LightGBM is boundaries data, need to convert to group size
             new_group = []
@@ -1279,8 +1075,11 @@ class Dataset(object):
         -------
         number of rows : int
         """
-        if self.__is_constructed():
-            return self.inner_dataset.num_data()
+        if self._is_constructed:
+            ret = ctypes.c_int64()
+            _safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
+                                                   ctypes.byref(ret)))
+            return ret.value
         else:
             raise LightGBMError("Cannot call num_data before construct, please call it explicitly")
@@ -1292,15 +1091,17 @@ class Dataset(object):
         -------
         number of columns : int
         """
-        if self.__is_constructed():
-            return self.inner_dataset.num_feature()
+        if self._is_constructed:
+            ret = ctypes.c_int64()
+            _safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
+                                                      ctypes.byref(ret)))
+            return ret.value
         else:
             raise LightGBMError("Cannot call num_feature before construct, please call it explicitly")
 class Booster(object):
-    """"A Booster of LightGBM.
-    """
+    """"Booster in LightGBM."""
     def __init__(self, params=None, train_set=None, model_file=None, silent=False):
         """
         Initialize the Booster.
@@ -1333,7 +1134,7 @@ class Booster(object):
         params_str = param_dict_to_str(params)
         """construct booster object"""
         _safe_call(_LIB.LGBM_BoosterCreate(
-            train_set._get_inner_dataset().handle,
+            train_set.construct().handle,
             c_str(params_str),
             ctypes.byref(self.handle)))
         """save reference to data"""
@@ -1427,7 +1228,7 @@ class Booster(object):
             raise LightGBMError("Add validation data failed, you should use same predictor for these data")
         _safe_call(_LIB.LGBM_BoosterAddValidData(
             self.handle,
-            data._get_inner_dataset().handle))
+            data.construct().handle))
         self.valid_sets.append(data)
         self.name_valid_sets.append(name)
         self.__num_dataset += 1
@@ -1481,7 +1282,7 @@ class Booster(object):
         self.train_set = train_set
         _safe_call(_LIB.LGBM_BoosterResetTrainingData(
             self.handle,
-            self.train_set._get_inner_dataset().handle))
+            self.train_set.construct().handle))
         self.__inner_predict_buffer[0] = None
         is_finished = ctypes.c_int(0)
         if fobj is None:
......
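
Since create_valid() and subset() now return unconstructed Dataset objects that are resolved against their reference only when construct() runs, the following sketch shows how the pieces fit together (illustrative data; assumes the public API as changed above, with construct() returning self):

    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(1000, 20)
    y = np.random.randint(0, 2, 1000)

    full = lgb.Dataset(X, label=y, free_raw_data=False).construct()

    # Validation data shares binning with its reference; nothing is built
    # until construct() is called on it (here explicitly, normally by train()).
    valid = full.create_valid(X[:200], label=y[:200]).construct()

    # subset() only records the indices and the parent handle; the actual
    # LGBM_DatasetGetSubset call happens lazily inside construct().
    fold = full.subset(list(range(500))).construct()

    print(full.num_data(), valid.num_data(), fold.num_data())   # 1000 200 500
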
@@ -5,7 +5,7 @@ from __future__ import absolute_import
 import inspect
 import numpy as np
-from .basic import LightGBMError, Dataset
+from .basic import LightGBMError, Dataset, IS_PY3
 from .engine import train
 '''sklearn'''
 try:
@@ -26,6 +26,13 @@ except ImportError:
     LGBMLabelEncoder = None
+def _argc(func):
+    if IS_PY3:
+        return len(inspect.signature(func).parameters)
+    else:
+        return len(inspect.getargspec(func).args)
 def _objective_function_wrapper(func):
     """Decorate an objective function
     Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
@@ -57,7 +64,7 @@ def _objective_function_wrapper(func):
     def inner(preds, dataset):
         """internal function"""
         labels = dataset.get_label()
-        argc = len(inspect.getargspec(func).args)
+        argc = _argc(func)
         if argc == 2:
             grad, hess = func(labels, preds)
         elif argc == 3:
@@ -122,7 +129,7 @@ def _eval_function_wrapper(func):
     def inner(preds, dataset):
         """internal function"""
         labels = dataset.get_label()
-        argc = len(inspect.getargspec(func).args)
+        argc = _argc(func)
         if argc == 2:
             return func(labels, preds)
         elif argc == 3:
......
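
The _argc helper added above exists because inspect.getargspec is deprecated on Python 3 (hence the "fix deprecated warning" in the commit message); inspect.signature is the supported replacement there. The wrappers only need the parameter count to decide how much extra information to pass to a user callback. A standalone sketch of the same idea, with made-up callback names:

    import inspect
    import sys

    IS_PY3 = (sys.version_info[0] == 3)

    def _argc(func):
        # Count the callback's parameters without triggering a DeprecationWarning on Python 3.
        if IS_PY3:
            return len(inspect.signature(func).parameters)
        else:
            return len(inspect.getargspec(func).args)

    def my_objective(y_true, y_pred):
        ...

    def my_objective_with_group(y_true, y_pred, group):
        ...

    print(_argc(my_objective))             # 2
    print(_argc(my_objective_with_group))  # 3
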
@@ -189,12 +189,13 @@ def test_booster():
     LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
     LIB.LGBM_BoosterAddValidData(booster, test)
     is_finished = ctypes.c_int(0)
-    for i in range(100):
+    for i in range(1, 101):
         LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))
         result = np.array([0.0], dtype=np.float64)
         out_len = ctypes.c_ulong(0)
         LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
-        print('%d Iteration test AUC %f' % (i, result[0]))
+        if i % 10 == 0:
+            print('%d Iteration test AUC %f' % (i, result[0]))
     LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
     LIB.LGBM_BoosterFree(booster)
     test_free_dataset(train)
......