Commit acbd4f34 authored by Guolin Ke's avatar Guolin Ke
Browse files

[python-package] free old dataset handle if set predictor/reference

parent 5b4b5d65
...@@ -540,12 +540,19 @@ class Dataset(object): ...@@ -540,12 +540,19 @@ class Dataset(object):
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
self.params = params self.params = params
self.free_raw_data = free_raw_data self.free_raw_data = free_raw_data
self._is_constructed = False
self.used_indices = None self.used_indices = None
self._predictor = None self._predictor = None
def __del__(self): def __del__(self):
_safe_call(_LIB.LGBM_DatasetFree(self.handle)) self._free_handle()
def _is_constructed(self):
return self.handle is not None
def _free_handle(self):
if self._is_constructed():
_safe_call(_LIB.LGBM_DatasetFree(self.handle))
self.handle = None
def _lazy_init(self, data, label=None, max_bin=255, reference=None, def _lazy_init(self, data, label=None, max_bin=255, reference=None,
weight=None, group=None, predictor=None, weight=None, group=None, predictor=None,
...@@ -587,7 +594,7 @@ class Dataset(object): ...@@ -587,7 +594,7 @@ class Dataset(object):
"""process for reference dataset""" """process for reference dataset"""
ref_dataset = None ref_dataset = None
if isinstance(reference, Dataset): if isinstance(reference, Dataset):
ref_dataset = reference.handle ref_dataset = reference.construct().handle
elif reference is not None: elif reference is not None:
raise TypeError('Reference dataset should be None or dataset instance') raise TypeError('Reference dataset should be None or dataset instance')
"""start construct data""" """start construct data"""
...@@ -718,8 +725,7 @@ class Dataset(object): ...@@ -718,8 +725,7 @@ class Dataset(object):
def construct(self): def construct(self):
"""Lazy init""" """Lazy init"""
if not self._is_constructed: if not self._is_constructed():
self._is_constructed = True
if self.reference is not None: if self.reference is not None:
if self.used_indices is None: if self.used_indices is None:
"""create valid""" """create valid"""
...@@ -729,10 +735,10 @@ class Dataset(object): ...@@ -729,10 +735,10 @@ class Dataset(object):
else: else:
"""construct subset""" """construct subset"""
used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices') used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices')
handle, self.handle = self.handle, ctypes.c_void_p() self.handle = ctypes.c_void_p()
params_str = param_dict_to_str(self.params) params_str = param_dict_to_str(self.params)
_safe_call(_LIB.LGBM_DatasetGetSubset( _safe_call(_LIB.LGBM_DatasetGetSubset(
handle, self.reference.construct().handle,
used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(used_indices.shape[0]), ctypes.c_int(used_indices.shape[0]),
c_str(params_str), c_str(params_str),
...@@ -791,7 +797,6 @@ class Dataset(object): ...@@ -791,7 +797,6 @@ class Dataset(object):
categorical_feature=self.categorical_feature, params=params) categorical_feature=self.categorical_feature, params=params)
ret._predictor = self._predictor ret._predictor = self._predictor
ret.used_indices = used_indices ret.used_indices = used_indices
ret.handle = self.handle
return ret return ret
def save_binary(self, filename): def save_binary(self, filename):
...@@ -824,6 +829,8 @@ class Dataset(object): ...@@ -824,6 +829,8 @@ class Dataset(object):
data: numpy array or list or None data: numpy array or list or None
The array ofdata to be set The array ofdata to be set
""" """
if not self._is_constructed():
raise Exception("cannot set filed before construct dataset handle")
if data is None: if data is None:
"""set to None""" """set to None"""
_safe_call(_LIB.LGBM_DatasetSetField( _safe_call(_LIB.LGBM_DatasetSetField(
...@@ -865,6 +872,8 @@ class Dataset(object): ...@@ -865,6 +872,8 @@ class Dataset(object):
info : array info : array
A numpy array of information of the data A numpy array of information of the data
""" """
if not self._is_constructed():
raise Exception("cannot Get filed before construct dataset handle")
tmp_out_len = ctypes.c_int() tmp_out_len = ctypes.c_int()
out_type = ctypes.c_int() out_type = ctypes.c_int()
ret = ctypes.POINTER(ctypes.c_void_p)() ret = ctypes.POINTER(ctypes.c_void_p)()
...@@ -899,7 +908,7 @@ class Dataset(object): ...@@ -899,7 +908,7 @@ class Dataset(object):
return return
if self.data is not None: if self.data is not None:
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
self._is_constructed = False self._free_handle()
else: else:
raise LightGBMError("Cannot set categorical feature after freed raw data,\ raise LightGBMError("Cannot set categorical feature after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.") Set free_raw_data=False when construct Dataset to avoid this.")
...@@ -913,7 +922,7 @@ class Dataset(object): ...@@ -913,7 +922,7 @@ class Dataset(object):
return return
if self.data is not None: if self.data is not None:
self._predictor = predictor self._predictor = predictor
self._is_constructed = False self._free_handle()
else: else:
raise LightGBMError("Cannot set predictor after freed raw data,Set free_raw_data=False when construct Dataset to avoid this.") raise LightGBMError("Cannot set predictor after freed raw data,Set free_raw_data=False when construct Dataset to avoid this.")
...@@ -933,7 +942,7 @@ class Dataset(object): ...@@ -933,7 +942,7 @@ class Dataset(object):
return return
if self.data is not None: if self.data is not None:
self.reference = reference self.reference = reference
self._is_constructed = False self._free_handle()
else: else:
raise LightGBMError("Cannot set reference after freed raw data,\ raise LightGBMError("Cannot set reference after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.") Set free_raw_data=False when construct Dataset to avoid this.")
...@@ -948,7 +957,7 @@ class Dataset(object): ...@@ -948,7 +957,7 @@ class Dataset(object):
Feature names Feature names
""" """
self.feature_name = feature_name self.feature_name = feature_name
if self._is_constructed and feature_name is not None: if self._is_constructed() and feature_name is not None:
if len(feature_name) != self.num_feature(): if len(feature_name) != self.num_feature():
raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature())) raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature()))
c_feature_name = [c_str(name) for name in feature_name] c_feature_name = [c_str(name) for name in feature_name]
...@@ -967,7 +976,7 @@ class Dataset(object): ...@@ -967,7 +976,7 @@ class Dataset(object):
The label information to be set into Dataset The label information to be set into Dataset
""" """
self.label = label self.label = label
if self._is_constructed: if self._is_constructed():
label = list_to_1d_numpy(label, name='label') label = list_to_1d_numpy(label, name='label')
self.set_field('label', label) self.set_field('label', label)
...@@ -981,7 +990,7 @@ class Dataset(object): ...@@ -981,7 +990,7 @@ class Dataset(object):
Weight for each data point Weight for each data point
""" """
self.weight = weight self.weight = weight
if self._is_constructed and weight is not None: if self._is_constructed() and weight is not None:
weight = list_to_1d_numpy(weight, name='weight') weight = list_to_1d_numpy(weight, name='weight')
self.set_field('weight', weight) self.set_field('weight', weight)
...@@ -995,7 +1004,7 @@ class Dataset(object): ...@@ -995,7 +1004,7 @@ class Dataset(object):
Init score for booster Init score for booster
""" """
self.init_score = init_score self.init_score = init_score
if self._is_constructed and init_score is not None: if self._is_constructed() and init_score is not None:
init_score = list_to_1d_numpy(init_score, name='init_score') init_score = list_to_1d_numpy(init_score, name='init_score')
self.set_field('init_score', init_score) self.set_field('init_score', init_score)
...@@ -1009,7 +1018,7 @@ class Dataset(object): ...@@ -1009,7 +1018,7 @@ class Dataset(object):
Group size of each group Group size of each group
""" """
self.group = group self.group = group
if self._is_constructed and group is not None: if self._is_constructed() and group is not None:
group = list_to_1d_numpy(group, np.int32, name='group') group = list_to_1d_numpy(group, np.int32, name='group')
self.set_field('group', group) self.set_field('group', group)
...@@ -1021,7 +1030,7 @@ class Dataset(object): ...@@ -1021,7 +1030,7 @@ class Dataset(object):
------- -------
label : array label : array
""" """
if self.label is None and self._is_constructed: if self.label is None and self._is_constructed():
self.label = self.get_field('label') self.label = self.get_field('label')
return self.label return self.label
...@@ -1033,7 +1042,7 @@ class Dataset(object): ...@@ -1033,7 +1042,7 @@ class Dataset(object):
------- -------
weight : array weight : array
""" """
if self.weight is None and self._is_constructed: if self.weight is None and self._is_constructed():
self.weight = self.get_field('weight') self.weight = self.get_field('weight')
return self.weight return self.weight
...@@ -1045,7 +1054,7 @@ class Dataset(object): ...@@ -1045,7 +1054,7 @@ class Dataset(object):
------- -------
init_score : array init_score : array
""" """
if self.init_score is None and self._is_constructed: if self.init_score is None and self._is_constructed():
self.init_score = self.get_field('init_score') self.init_score = self.get_field('init_score')
return self.init_score return self.init_score
...@@ -1057,7 +1066,7 @@ class Dataset(object): ...@@ -1057,7 +1066,7 @@ class Dataset(object):
------- -------
init_score : array init_score : array
""" """
if self.group is None and self._is_constructed: if self.group is None and self._is_constructed():
self.group = self.get_field('group') self.group = self.get_field('group')
if self.group is not None: if self.group is not None:
# group data from LightGBM is boundaries data, need to convert to group size # group data from LightGBM is boundaries data, need to convert to group size
...@@ -1075,7 +1084,7 @@ class Dataset(object): ...@@ -1075,7 +1084,7 @@ class Dataset(object):
------- -------
number of rows : int number of rows : int
""" """
if self._is_constructed: if self._is_constructed():
ret = ctypes.c_int() ret = ctypes.c_int()
_safe_call(_LIB.LGBM_DatasetGetNumData(self.handle, _safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
ctypes.byref(ret))) ctypes.byref(ret)))
...@@ -1091,7 +1100,7 @@ class Dataset(object): ...@@ -1091,7 +1100,7 @@ class Dataset(object):
------- -------
number of columns : int number of columns : int
""" """
if self._is_constructed: if self._is_constructed():
ret = ctypes.c_int() ret = ctypes.c_int()
_safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle, _safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
ctypes.byref(ret))) ctypes.byref(ret)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment