Commit acbd4f34 authored by Guolin Ke's avatar Guolin Ke
Browse files

[python-package] free old dataset handle if set predictor/reference

parent 5b4b5d65
......@@ -540,12 +540,19 @@ class Dataset(object):
self.categorical_feature = categorical_feature
self.params = params
self.free_raw_data = free_raw_data
self._is_constructed = False
self.used_indices = None
self._predictor = None
def __del__(self):
_safe_call(_LIB.LGBM_DatasetFree(self.handle))
self._free_handle()
def _is_constructed(self):
return self.handle is not None
def _free_handle(self):
if self._is_constructed():
_safe_call(_LIB.LGBM_DatasetFree(self.handle))
self.handle = None
def _lazy_init(self, data, label=None, max_bin=255, reference=None,
weight=None, group=None, predictor=None,
......@@ -587,7 +594,7 @@ class Dataset(object):
"""process for reference dataset"""
ref_dataset = None
if isinstance(reference, Dataset):
ref_dataset = reference.handle
ref_dataset = reference.construct().handle
elif reference is not None:
raise TypeError('Reference dataset should be None or dataset instance')
"""start construct data"""
......@@ -718,8 +725,7 @@ class Dataset(object):
def construct(self):
"""Lazy init"""
if not self._is_constructed:
self._is_constructed = True
if not self._is_constructed():
if self.reference is not None:
if self.used_indices is None:
"""create valid"""
......@@ -729,10 +735,10 @@ class Dataset(object):
else:
"""construct subset"""
used_indices = list_to_1d_numpy(self.used_indices, np.int32, name='used_indices')
handle, self.handle = self.handle, ctypes.c_void_p()
self.handle = ctypes.c_void_p()
params_str = param_dict_to_str(self.params)
_safe_call(_LIB.LGBM_DatasetGetSubset(
handle,
self.reference.construct().handle,
used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(used_indices.shape[0]),
c_str(params_str),
......@@ -791,7 +797,6 @@ class Dataset(object):
categorical_feature=self.categorical_feature, params=params)
ret._predictor = self._predictor
ret.used_indices = used_indices
ret.handle = self.handle
return ret
def save_binary(self, filename):
......@@ -824,6 +829,8 @@ class Dataset(object):
data: numpy array or list or None
The array ofdata to be set
"""
if not self._is_constructed():
raise Exception("cannot set filed before construct dataset handle")
if data is None:
"""set to None"""
_safe_call(_LIB.LGBM_DatasetSetField(
......@@ -865,6 +872,8 @@ class Dataset(object):
info : array
A numpy array of information of the data
"""
if not self._is_constructed():
raise Exception("cannot Get filed before construct dataset handle")
tmp_out_len = ctypes.c_int()
out_type = ctypes.c_int()
ret = ctypes.POINTER(ctypes.c_void_p)()
......@@ -899,7 +908,7 @@ class Dataset(object):
return
if self.data is not None:
self.categorical_feature = categorical_feature
self._is_constructed = False
self._free_handle()
else:
raise LightGBMError("Cannot set categorical feature after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.")
......@@ -913,7 +922,7 @@ class Dataset(object):
return
if self.data is not None:
self._predictor = predictor
self._is_constructed = False
self._free_handle()
else:
raise LightGBMError("Cannot set predictor after freed raw data,Set free_raw_data=False when construct Dataset to avoid this.")
......@@ -933,7 +942,7 @@ class Dataset(object):
return
if self.data is not None:
self.reference = reference
self._is_constructed = False
self._free_handle()
else:
raise LightGBMError("Cannot set reference after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.")
......@@ -948,7 +957,7 @@ class Dataset(object):
Feature names
"""
self.feature_name = feature_name
if self._is_constructed and feature_name is not None:
if self._is_constructed() and feature_name is not None:
if len(feature_name) != self.num_feature():
raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature()))
c_feature_name = [c_str(name) for name in feature_name]
......@@ -967,7 +976,7 @@ class Dataset(object):
The label information to be set into Dataset
"""
self.label = label
if self._is_constructed:
if self._is_constructed():
label = list_to_1d_numpy(label, name='label')
self.set_field('label', label)
......@@ -981,7 +990,7 @@ class Dataset(object):
Weight for each data point
"""
self.weight = weight
if self._is_constructed and weight is not None:
if self._is_constructed() and weight is not None:
weight = list_to_1d_numpy(weight, name='weight')
self.set_field('weight', weight)
......@@ -995,7 +1004,7 @@ class Dataset(object):
Init score for booster
"""
self.init_score = init_score
if self._is_constructed and init_score is not None:
if self._is_constructed() and init_score is not None:
init_score = list_to_1d_numpy(init_score, name='init_score')
self.set_field('init_score', init_score)
......@@ -1009,7 +1018,7 @@ class Dataset(object):
Group size of each group
"""
self.group = group
if self._is_constructed and group is not None:
if self._is_constructed() and group is not None:
group = list_to_1d_numpy(group, np.int32, name='group')
self.set_field('group', group)
......@@ -1021,7 +1030,7 @@ class Dataset(object):
-------
label : array
"""
if self.label is None and self._is_constructed:
if self.label is None and self._is_constructed():
self.label = self.get_field('label')
return self.label
......@@ -1033,7 +1042,7 @@ class Dataset(object):
-------
weight : array
"""
if self.weight is None and self._is_constructed:
if self.weight is None and self._is_constructed():
self.weight = self.get_field('weight')
return self.weight
......@@ -1045,7 +1054,7 @@ class Dataset(object):
-------
init_score : array
"""
if self.init_score is None and self._is_constructed:
if self.init_score is None and self._is_constructed():
self.init_score = self.get_field('init_score')
return self.init_score
......@@ -1057,7 +1066,7 @@ class Dataset(object):
-------
init_score : array
"""
if self.group is None and self._is_constructed:
if self.group is None and self._is_constructed():
self.group = self.get_field('group')
if self.group is not None:
# group data from LightGBM is boundaries data, need to convert to group size
......@@ -1075,7 +1084,7 @@ class Dataset(object):
-------
number of rows : int
"""
if self._is_constructed:
if self._is_constructed():
ret = ctypes.c_int()
_safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
ctypes.byref(ret)))
......@@ -1091,7 +1100,7 @@ class Dataset(object):
-------
number of columns : int
"""
if self._is_constructed:
if self._is_constructed():
ret = ctypes.c_int()
_safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
ctypes.byref(ret)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment