Commit 629fc047 authored by Guolin Ke's avatar Guolin Ke
Browse files

more flexible python basic objects

parent b41e0f0a
...@@ -37,6 +37,7 @@ public: ...@@ -37,6 +37,7 @@ public:
/*! /*!
* \brief Merge model from other boosting object * \brief Merge model from other boosting object
Will insert to the front of current boosting object
* \param other * \param other
*/ */
virtual void MergeFrom(const Boosting* other) = 0; virtual void MergeFrom(const Boosting* other) = 0;
......
...@@ -126,16 +126,27 @@ C_API_DTYPE_INT64 =3 ...@@ -126,16 +126,27 @@ C_API_DTYPE_INT64 =3
"""Matric is row major in python""" """Matric is row major in python"""
C_API_IS_ROW_MAJOR =1 C_API_IS_ROW_MAJOR =1
C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2

# Map field name -> dtype expected by the C API in set_field/get_field.
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
                     # fix: key was misspelled "wegiht", so set_field('weight')
                     # / get_field('weight') raised KeyError on the lookup.
                     "weight": C_API_DTYPE_FLOAT32,
                     "init_score": C_API_DTYPE_FLOAT32,
                     "group_id": C_API_DTYPE_INT32,
                     "group": C_API_DTYPE_INT32,
                     }
def c_float_array(data): def c_float_array(data):
"""Convert numpy array / list to c float array.""" """Convert numpy array / list to c float array."""
if isinstance(data, list): if isinstance(data, list):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if is_numpy_1d_array(data): if is_numpy_1d_array(data):
if data.dtype == np.float32: if data.dtype == np.float32:
ptr_data = data.ctypes.data_as(ctypes.c_float) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
type_data = C_API_DTYPE_FLOAT32 type_data = C_API_DTYPE_FLOAT32
elif data.dtype == np.float64: elif data.dtype == np.float64:
ptr_data = data.ctypes.data_as(ctypes.c_double) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
type_data = C_API_DTYPE_FLOAT64 type_data = C_API_DTYPE_FLOAT64
else: else:
raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype)) raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype))
...@@ -149,10 +160,10 @@ def c_int_array(data): ...@@ -149,10 +160,10 @@ def c_int_array(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if is_numpy_1d_array(data): if is_numpy_1d_array(data):
if data.dtype == np.int32: if data.dtype == np.int32:
ptr_data = data.ctypes.data_as(ctypes.c_int32) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
type_data = C_API_DTYPE_INT32 type_data = C_API_DTYPE_INT32
elif data.dtype == np.int64: elif data.dtype == np.int64:
ptr_data = data.ctypes.data_as(ctypes.c_int64) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
type_data = C_API_DTYPE_INT64 type_data = C_API_DTYPE_INT64
else: else:
raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype)) raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype))
...@@ -160,19 +171,188 @@ def c_int_array(data): ...@@ -160,19 +171,188 @@ def c_int_array(data):
raise TypeError("Unknow type({})".format(type(data).__name__)) raise TypeError("Unknow type({})".format(type(data).__name__))
return (ptr_data, type_data) return (ptr_data, type_data)
class Predictor(object):
    """A Predictor of LightGBM.

    Prediction-only wrapper around a booster handle: it can either load a
    model from a file or reuse an already-created booster handle.
    """
    def __init__(self, model_file=None, params=None, booster_handle=None, is_manage_handle=True):
        # pylint: disable=invalid-name
        """Initialize the Predictor.

        Parameters
        ----------
        model_file : string
            Path to the model file.
        params : dict
            Parameters for boosters.
        booster_handle : ctypes.c_void_p, optional
            Handle of an existing booster to predict with.
        is_manage_handle : bool, optional
            Whether this Predictor owns booster_handle (and frees it in __del__).
        """
        self.handle = ctypes.c_void_p()
        self.__is_manage_handle = True
        if model_file is not None:
            # Prediction task: load the model from file.
            out_num_total_model = ctypes.c_int64(0)
            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
                c_str(model_file),
                ctypes.byref(out_num_total_model),
                ctypes.byref(self.handle)))
            self.__num_total_model = out_num_total_model.value
            tmp_out_len = ctypes.c_int64(0)
            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
                self.handle,
                ctypes.byref(tmp_out_len)))
            self.num_class = tmp_out_len.value
        elif booster_handle is not None:
            # Reuse an existing handle; ownership is controlled by the caller.
            self.__is_manage_handle = is_manage_handle
            self.handle = booster_handle
            tmp_out_len = ctypes.c_int64(0)
            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
                self.handle,
                ctypes.byref(tmp_out_len)))
            self.num_class = tmp_out_len.value
            _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
                self.handle,
                ctypes.byref(tmp_out_len)))
            self.__num_total_model = self.num_class * tmp_out_len.value
        else:
            raise TypeError('Need Model file to create a booster')

    def __del__(self):
        if self.__is_manage_handle:
            _safe_call(_LIB.LGBM_BoosterFree(self.handle))

    def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
        """Predict for data.

        Parameters
        ----------
        data : string/numpy array/scipy.sparse
            Data source for prediction; a string is treated as a path to a txt file.
        num_iteration : int
            Number of iterations used for prediction; <= 0 means use all.
        raw_score : bool
            Whether to predict raw scores.
        pred_leaf : bool
            Whether to predict leaf indices instead of scores.
        data_has_header : bool
            Only used for txt data: whether the file has a header line.
        is_reshape : bool
            Whether to reshape a multi-output result to (nrow, ncol).

        Returns
        -------
        Prediction result as a numpy array.
        """
        if isinstance(data, Dataset):
            raise TypeError("cannot use Dataset instance for prediction, please use raw data instead")
        predict_type = C_API_PREDICT_NORMAL
        if raw_score:
            predict_type = C_API_PREDICT_RAW_SCORE
        if pred_leaf:
            predict_type = C_API_PREDICT_LEAF_INDEX
        int_data_has_header = 1 if data_has_header else 0
        if is_str(data):
            # Predict through the C API's file interface via a temp result file.
            tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name
            _safe_call(_LIB.LGBM_BoosterPredictForFile(
                self.handle,
                c_str(data),
                int_data_has_header,
                predict_type,
                num_iteration,
                c_str(tmp_pred_fname)))
            # fix: the temp file handle was previously leaked (never closed)
            with open(tmp_pred_fname, "r") as tmp_file:
                lines = tmp_file.readlines()
            nrow = len(lines)
            preds = []
            for line in lines:
                for token in line.split('\t'):
                    preds.append(float(token))
            preds = np.array(preds, copy=False)
            os.remove(tmp_pred_fname)
        elif isinstance(data, scipy.sparse.csr_matrix):
            preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type)
        elif isinstance(data, np.ndarray):
            preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
        else:
            try:
                csr = scipy.sparse.csr_matrix(data)
                # fix: result was assigned to an unused "res", leaving
                # preds/nrow unbound for the code below
                preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type)
            except Exception:
                raise TypeError('can not predict data for type {}'.format(type(data).__name__))
        if pred_leaf:
            preds = preds.astype(np.int32)
        if preds.size != nrow and is_reshape:
            if preds.size % nrow == 0:
                ncol = int(preds.size / nrow)
                preds = preds.reshape(nrow, ncol)
            else:
                raise ValueError('len of predict result(%d) cannot be divide nrow(%d)' % (preds.size, nrow))
        return preds

    def __num_preds(self, num_iteration, nrow, predict_type):
        """Number of prediction values the C API returns for nrow rows.

        Shared by __pred_for_np2d / __pred_for_csr (was duplicated in both).
        """
        n_preds = self.num_class * nrow
        if predict_type == C_API_PREDICT_LEAF_INDEX:
            if num_iteration > 0:
                n_preds *= num_iteration
            else:
                # fix: use integer division; "/" produced a float n_preds that
                # broke np.zeros and the exact-size check below
                n_preds *= self.__num_total_model // self.num_class
        return n_preds

    def __pred_for_np2d(self, mat, num_iteration, predict_type):
        """
        Predict for a 2-D numpy matrix.
        """
        if len(mat.shape) != 2:
            raise ValueError('Input numpy.ndarray must be 2 dimensional')
        if mat.dtype == np.float32 or mat.dtype == np.float64:
            data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
        else:
            # change non-float data to float data, need to copy
            data = np.array(mat.reshape(mat.size), dtype=np.float32)
        ptr_data, type_ptr_data = c_float_array(data)
        n_preds = self.__num_preds(num_iteration, mat.shape[0], predict_type)
        preds = np.zeros(n_preds, dtype=np.float32)
        out_num_preds = ctypes.c_int64(0)
        # fix: was "LIB....", a NameError; the module-level handle is _LIB
        _safe_call(_LIB.LGBM_BoosterPredictForMat(
            self.handle,
            ptr_data,
            type_ptr_data,
            mat.shape[0],
            mat.shape[1],
            C_API_IS_ROW_MAJOR,
            predict_type,
            num_iteration,
            ctypes.byref(out_num_preds),
            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
            ))
        if n_preds != out_num_preds.value:
            raise ValueError("incorrect number for predict result")
        return preds, mat.shape[0]

    def __pred_for_csr(self, csr, num_iteration, predict_type):
        """
        Predict for a csr data
        """
        nrow = len(csr.indptr) - 1
        n_preds = self.__num_preds(num_iteration, nrow, predict_type)
        preds = np.zeros(n_preds, dtype=np.float32)
        out_num_preds = ctypes.c_int64(0)
        ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
        ptr_data, type_ptr_data = c_float_array(csr.data)
        # fix: was "LIB....", a NameError; the module-level handle is _LIB
        _safe_call(_LIB.LGBM_BoosterPredictForCSR(
            self.handle,
            ptr_indptr,
            type_ptr_indptr,
            csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            type_ptr_data,
            len(csr.indptr),
            len(csr.data),
            csr.shape[1],
            predict_type,
            num_iteration,
            ctypes.byref(out_num_preds),
            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
            ))
        if n_preds != out_num_preds.value:
            raise ValueError("incorrect number for predict result")
        return preds, nrow
class Dataset(object): class Dataset(object):
"""Dataset used in LightGBM. """Dataset used in LightGBM.
Dataset is a internal data structure that used by LightGBM Dataset is a internal data structure that used by LightGBM
You can construct Dataset from numpy.arrays
""" """
_feature_names = None def __init__(self, data, label=None, max_bin=255, reference=None,
weight=None, group_id=None, predictor=None,
def __init__(self, data, max_bin=255, reference=None, silent=False, params=None):
label=None, weight=None, group_id=None,
silent=False, feature_names=None,
other_params=None, is_continue_train=False):
""" """
Dataset used in LightGBM. Dataset used in LightGBM.
...@@ -181,41 +361,35 @@ class Dataset(object): ...@@ -181,41 +361,35 @@ class Dataset(object):
data : string/numpy array/scipy.sparse data : string/numpy array/scipy.sparse
Data source of Dataset. Data source of Dataset.
When data is string type, it represents the path of txt file, When data is string type, it represents the path of txt file,
label : list or numpy 1-D array, optional
Label of the data
max_bin : int, required max_bin : int, required
max number of discrete bin for features max number of discrete bin for features
reference : Other Dataset, optional reference : Other Dataset, optional
If this dataset validation, need to use training data as reference If this dataset validation, need to use training data as reference
label : list or numpy 1-D array, optional
Label of the training data.
weight : list or numpy 1-D array , optional weight : list or numpy 1-D array , optional
Weight for each instance. Weight for each instance.
group_id : list or numpy 1-D array , optional group_id : list or numpy 1-D array , optional
group/query id for each instance. Note: if having group/query id, data should group by this id group/query id for each instance. Note: if having group/query id, data should group by this id
silent : boolean, optional silent : boolean, optional
Whether print messages during construction Whether print messages during construction
feature_names : list, optional params: dict, optional
Set names for features.
other_params: dict, optional
other parameters other parameters
""" """
if data is None: if data is None:
self.handle = None self.handle = None
return return
"""save raw data for continue train """
if is_continue_train:
self.raw_data = data
else:
self.raw_data = None
self.data_has_header = False self.data_has_header = False
"""process for args""" """process for args"""
if params is None:
params = {} params = {}
self.max_bin = max_bin
self.predictor = predictor
params["max_bin"] = max_bin params["max_bin"] = max_bin
if silent: if silent:
params["verbose"] = 0 params["verbose"] = 0
if other_params: else:
other_params.update(params) params["verbose"] = 1
params = other_params
params_str = dict_to_str(params) params_str = dict_to_str(params)
"""process for reference dataset""" """process for reference dataset"""
ref_dataset = None ref_dataset = None
...@@ -228,7 +402,7 @@ class Dataset(object): ...@@ -228,7 +402,7 @@ class Dataset(object):
"""check data has header or not""" """check data has header or not"""
if "has_header" in params or "header" in params: if "has_header" in params or "header" in params:
if params["has_header"].lower() == "true" or params["header"].lower() == "true": if params["has_header"].lower() == "true" or params["header"].lower() == "true":
data_has_header = True self.data_has_header = True
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_CreateDatasetFromFile( _safe_call(_LIB.LGBM_CreateDatasetFromFile(
c_str(data), c_str(data),
...@@ -242,8 +416,6 @@ class Dataset(object): ...@@ -242,8 +416,6 @@ class Dataset(object):
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
if self.raw_data is not None:
self.raw_data = csr
self.__init_from_csr(csr) self.__init_from_csr(csr)
except: except:
raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__)) raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__))
...@@ -253,14 +425,52 @@ class Dataset(object): ...@@ -253,14 +425,52 @@ class Dataset(object):
self.__group = None self.__group = None
if label is not None: if label is not None:
self.set_label(label) self.set_label(label)
if self.get_label() is None:
raise ValueError("label should not be None")
if weight is not None: if weight is not None:
self.set_weight(weight) self.set_weight(weight)
if group_id is not None: if group_id is not None:
self.set_group_id(group_id) self.set_group_id(group_id)
self.feature_names = feature_names # load init score
if self.predictor is not None and isinstance(self.predictor, Predictor):
init_score = self.predictor.predict(data,
raw_score=True,
data_has_header=self.data_has_header,
is_reshape=False)
if self.predictor.num_class > 1:
# need re group init score
new_init_score = np.zeros(init_score.size(), dtype=np.float32)
num_data = self.num_data()
for i in range(num_data):
for j in range(self.predictor.num_class):
new_init_score[j * num_data + i] = init_score[i * self.predictor.num_class + j]
init_score = new_init_score
self.set_init_score(init_score)
def new_valid_dataset(self, data, label=None, weight=None, group_id=None,
                      silent=False, params=None):
    """
    Create validation data aligned with the current dataset.

    Parameters
    ----------
    data : string/numpy array/scipy.sparse
        Data source of Dataset.
        When data is string type, it represents the path of txt file,
    label : list or numpy 1-D array, optional
        Label of the training data.
    weight : list or numpy 1-D array , optional
        Weight for each instance.
    group_id : list or numpy 1-D array , optional
        group/query id for each instance. Note: if having group/query id, data should group by this id
    silent : boolean, optional
        Whether print messages during construction
    params : dict, optional
        other parameters
    """
    # The new Dataset inherits max_bin/predictor from self and uses self as
    # the bin-mapper reference, so validation data is binned consistently.
    return Dataset(data, label=label, max_bin=self.max_bin, reference=self,
                   weight=weight, group_id=group_id, predictor=self.predictor,
                   silent=silent, params=params)
def __init_from_np2d(self, mat, params_str, ref_dataset): def __init_from_np2d(self, mat, params_str, ref_dataset):
""" """
...@@ -301,7 +511,7 @@ class Dataset(object): ...@@ -301,7 +511,7 @@ class Dataset(object):
_safe_call(_LIB.LGBM_CreateDatasetFromCSR( _safe_call(_LIB.LGBM_CreateDatasetFromCSR(
ptr_indptr, ptr_indptr,
type_ptr_indptr, type_ptr_indptr,
csr.indices.ctypes.data_as(ctypes.c_int32), csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ptr_data, ptr_data,
type_ptr_data, type_ptr_data,
len(csr.indptr), len(csr.indptr),
...@@ -327,19 +537,23 @@ class Dataset(object): ...@@ -327,19 +537,23 @@ class Dataset(object):
info : array info : array
a numpy array of information of the data a numpy array of information of the data
""" """
out_len = ctypes.c_int32() tmp_out_len = ctypes.c_int64()
out_type = ctypes.c_int32() out_type = ctypes.c_int32()
ret = ctypes.POINTER(ctypes.c_void_p)() ret = ctypes.POINTER(ctypes.c_void_p)()
_safe_call(_LIB.LGBM_DatasetGetField( _safe_call(_LIB.LGBM_DatasetGetField(
self.handle, self.handle,
c_str(field_name), c_str(field_name),
ctypes.byref(out_len), ctypes.byref(tmp_out_len),
ctypes.byref(ret), ctypes.byref(ret),
ctypes.byref(out_type))) ctypes.byref(out_type)))
if out_type.value != FIELD_TYPE_MAPPER[field_name]:
raise TypeError("Return type error for get_field")
if tmp_out_len.value == 0:
return None
if out_type.value == C_API_DTYPE_INT32: if out_type.value == C_API_DTYPE_INT32:
return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(c_int32), out_len.value)) return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value)
elif out_type.value == C_API_DTYPE_FLOAT32: elif out_type.value == C_API_DTYPE_FLOAT32:
return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(c_float), out_len.value)) return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value)
else: else:
raise TypeError("unknow type") raise TypeError("unknow type")
...@@ -351,19 +565,29 @@ class Dataset(object): ...@@ -351,19 +565,29 @@ class Dataset(object):
field_name: str field_name: str
The field name of the information The field name of the information
data: numpy array or list data: numpy array or list or None
The array ofdata to be set The array ofdata to be set
""" """
if data is None:
_safe_call(_LIB.LGBM_DatasetSetField(
self.handle,
c_str(field_name),
None,
0,
FIELD_TYPE_MAPPER[field_name]))
return
if not is_numpy_1d_array(data): if not is_numpy_1d_array(data):
raise TypeError("Unknow type({})".format(type(data).__name__)) raise TypeError("Unknow type({})".format(type(data).__name__))
if data.dtype == np.float32: if data.dtype == np.float32:
ptr_data = data.ctypes.data_as(ctypes.c_float) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
type_data = C_API_DTYPE_FLOAT32 type_data = C_API_DTYPE_FLOAT32
elif data.dtype == np.int32: elif data.dtype == np.int32:
ptr_data = data.ctypes.data_as(ctypes.c_int32) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
type_data = C_API_DTYPE_INT32 type_data = C_API_DTYPE_INT32
else: else:
raise TypeError("excepted np.float32 or np.int32, met type({})".format(data.dtype)) raise TypeError("excepted np.float32 or np.int32, met type({})".format(data.dtype))
if type_data != FIELD_TYPE_MAPPER[field_name]:
raise TypeError("type error for set_field")
_safe_call(_LIB.LGBM_DatasetSetField( _safe_call(_LIB.LGBM_DatasetSetField(
self.handle, self.handle,
c_str(field_name), c_str(field_name),
...@@ -406,6 +630,7 @@ class Dataset(object): ...@@ -406,6 +630,7 @@ class Dataset(object):
weight : array like weight : array like
Weight for each data point Weight for each data point
""" """
if weight is not None:
weight = list_to_1d_numpy(weight, np.float32) weight = list_to_1d_numpy(weight, np.float32)
if weight.dtype != np.float32: if weight.dtype != np.float32:
weight = weight.astype(np.float32, copy=False) weight = weight.astype(np.float32, copy=False)
...@@ -419,10 +644,11 @@ class Dataset(object): ...@@ -419,10 +644,11 @@ class Dataset(object):
score: array like score: array like
""" """
if score is not None:
score = list_to_1d_numpy(score, np.float32) score = list_to_1d_numpy(score, np.float32)
if score.dtype != np.float32: if score.dtype != np.float32:
score = score.astype(np.float32, copy=False) score = score.astype(np.float32, copy=False)
self.__init_score = init_score self.__init_score = score
self.set_field('init_score', score) self.set_field('init_score', score)
def set_group(self, group): def set_group(self, group):
...@@ -433,6 +659,7 @@ class Dataset(object): ...@@ -433,6 +659,7 @@ class Dataset(object):
group : array like group : array like
Group size of each group Group size of each group
""" """
if group is not None:
group = list_to_1d_numpy(group, np.int32) group = list_to_1d_numpy(group, np.int32)
if group.dtype != np.int32: if group.dtype != np.int32:
group = group.astype(np.int32, copy=False) group = group.astype(np.int32, copy=False)
...@@ -448,6 +675,7 @@ class Dataset(object): ...@@ -448,6 +675,7 @@ class Dataset(object):
group : array like group : array like
group_id of Dataset (used for ranking). group_id of Dataset (used for ranking).
""" """
if group_id is not None:
group_id = list_to_1d_numpy(group_id, np.int32) group_id = list_to_1d_numpy(group_id, np.int32)
if group_id.dtype != np.int32: if group_id.dtype != np.int32:
group_id = group_id.astype(np.int32, copy=False) group_id = group_id.astype(np.int32, copy=False)
...@@ -462,6 +690,8 @@ class Dataset(object): ...@@ -462,6 +690,8 @@ class Dataset(object):
""" """
if self.__label is None: if self.__label is None:
self.__label = self.get_field('label') self.__label = self.get_field('label')
if self.__label is None:
raise TypeError("label should not be None")
return self.__label return self.__label
def get_weight(self): def get_weight(self):
...@@ -521,58 +751,11 @@ class Dataset(object): ...@@ -521,58 +751,11 @@ class Dataset(object):
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
@property
def feature_names(self):
"""Get feature names (column labels).
Returns
-------
feature_names : list
"""
if self._feature_names is None:
self._feature_names = ['Column_{0}'.format(i) for i in range(self.num_col())]
return self._feature_names
@feature_names.setter
def feature_names(self, feature_names):
"""Set feature names (column labels).
Parameters
----------
feature_names : list
Labels for features
"""
if feature_names is not None:
# validate feature name
if not isinstance(feature_names, list):
feature_names = list(feature_names)
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []<
if not all(isinstance(f, STRING_TYPES) and
not any(x in f for x in set(('[', ']', '<')))
for f in feature_names):
raise ValueError('feature_names may not contain [, ] or <')
self._feature_names = feature_names
else:
self._feature_names = None
C_API_PREDICT_NORMAL =0
C_API_PREDICT_RAW_SCORE =1
C_API_PREDICT_LEAF_INDEX =2
class Booster(object): class Booster(object):
""""A Booster of of LightGBM. """"A Booster of of LightGBM.
""" """
def __init__(self, params=None, train_set=None, model_file=None, silent=False):
feature_names = None
def __init__(self,params=None,
train_set=None, valid_sets=None,
name_valid_sets=None, model_file=None):
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Initialize the Booster. """Initialize the Booster.
...@@ -582,83 +765,46 @@ class Booster(object): ...@@ -582,83 +765,46 @@ class Booster(object):
Parameters for boosters. Parameters for boosters.
train_set : Dataset train_set : Dataset
training dataset training dataset
valid_sets : List of Dataset or None
validation datasets
name_valid_sets : List of string
name of validation datasets
model_file : string model_file : string
Path to the model file. Path to the model file.
If tarin_set is not None, used for continued train.
else used for loading model prediction task
""" """
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
self.__need_reload_eval_info = True
self.__is_manage_handle = True
if params is None:
params = {}
if silent:
params["verbose"] = 0
else:
params["verbose"] = 1
if train_set is not None: if train_set is not None:
"""Training task""" """Training task"""
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__)) raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__))
valid_handles = None
n_valid = 0
if valid_sets is not None:
for valid in valid_sets:
if not isinstance(valid, Dataset):
raise TypeError('valid data should be Dataset instance, met{}'.format(type(valid).__name__))
valid_handles = c_array(ctypes.c_void_p, [valid.handle for valid in valid_sets])
if name_valid_sets is None:
name_valid_sets = ["valid_{}".format(x+1) for x in range(len(valid_sets)) ]
if len(valid_sets) != len(name_valid_sets):
raise Exception('len of valid_sets should be equal with len of name_valid_sets')
n_valid = len(valid_sets)
ref_input_model = None
params_str = dict_to_str(params) params_str = dict_to_str(params)
if model_file is not None:
ref_input_model = c_str(model_file)
"""construct booster object""" """construct booster object"""
_safe_call(_LIB.LGBM_BoosterCreate( _safe_call(_LIB.LGBM_BoosterCreate(
train_set.handle, train_set.handle,
valid_handles,
n_valid,
c_str(params_str), c_str(params_str),
ref_input_model,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
"""if need to continue train"""
if model_file is not None:
self.__init_continue_train(train_set)
if valid_sets is not None:
for valid in valid_sets:
self.__init_continue_train(valid)
"""save reference to data""" """save reference to data"""
self.train_set = train_set self.train_set = train_set
self.valid_sets = valid_sets self.valid_sets = []
self.name_valid_sets = name_valid_sets self.name_valid_sets = []
self.__num_dataset = 1 + n_valid self.__num_dataset = 1
self.__training_score = None self.init_predictor = train_set.predictor
out_len = ctypes.c_int64(0) if self.init_predictor is not None:
_safe_call(_LIB.LGBM_BoosterMerge(
self.handle,
self.init_predictor.handle))
out_num_class = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle, self.handle,
ctypes.byref(out_len))) ctypes.byref(out_num_class)))
self.__num_class = out_len.value self.__num_class = out_num_class.value
"""buffer for inner predict""" """buffer for inner predict"""
self.__inner_predict_buffer = [None for _ in range(self.__num_dataset)] self.__inner_predict_buffer = [None]
"""Get num of inner evals""" self.__get_eval_info()
_safe_call(_LIB.LGBM_BoosterGetEvalCounts(
self.handle,
ctypes.byref(out_len)))
self.__num_inner_eval = out_len.value
if self.__num_inner_eval > 0:
"""Get name of evals"""
string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)]
ptr_string_buffers = (ctypes.c_char_p*self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames(
self.handle,
ctypes.byref(out_len),
ptr_string_buffers))
if self.__num_inner_eval != out_len.value:
raise ValueError("size of eval names doesn't equal with num_evals")
self.__name_inner_eval = []
for i in range(self.__num_inner_eval):
self.__name_inner_eval.append(string_buffers[i].value.decode())
elif model_file is not None: elif model_file is not None:
"""Prediction task""" """Prediction task"""
out_num_total_model = ctypes.c_int64(0) out_num_total_model = ctypes.c_int64(0)
...@@ -667,18 +813,40 @@ class Booster(object): ...@@ -667,18 +813,40 @@ class Booster(object):
ctypes.byref(out_num_total_model), ctypes.byref(out_num_total_model),
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
self.__num_total_model = out_num_total_model.value self.__num_total_model = out_num_total_model.value
out_len = ctypes.c_int64(0) out_num_class = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle, self.handle,
ctypes.byref(out_len))) ctypes.byref(out_num_class)))
self.__num_class = out_len.value self.__num_class = out_num_class.value
else: else:
raise TypeError('At least need training dataset or model file to create booster instance') raise TypeError('At least need training dataset or model file to create booster instance')
def __del__(self): def __del__(self):
if self.handle is not None and self.__is_manage_handle:
_safe_call(_LIB.LGBM_BoosterFree(self.handle)) _safe_call(_LIB.LGBM_BoosterFree(self.handle))
def update(self, fobj=None): def add_valid_data(self, data, name):
if data.predictor is not self.init_predictor:
raise Exception("Add validation data failed, you should use same predictor for these data")
_safe_call(_LIB.LGBM_BoosterAddValidData(
self.handle,
data.handle))
self.valid_sets.append(data)
self.name_valid_sets.append(name)
self.__num_dataset += 1
def ResetParameter(self, params, silent=False):
    """Reset parameters of the booster.

    Parameters
    ----------
    params : dict
        New parameters for the booster (not modified by this call).
    silent : boolean, optional
        Whether to suppress LightGBM's messages.
    """
    self.__need_reload_eval_info = True
    # fix: copy before writing "verbose" so the caller's dict is not mutated
    params = dict(params)
    if silent:
        params["verbose"] = 0
    else:
        params["verbose"] = 1
    params_str = dict_to_str(params)
    _safe_call(_LIB.LGBM_BoosterResetParameter(
        self.handle,
        c_str(params_str)))
def update(self, train_set=None, fobj=None):
""" """
Update for one iteration Update for one iteration
Note: for multi-class task, the score is group by class_id first, then group by row_id Note: for multi-class task, the score is group by class_id first, then group by row_id
...@@ -686,6 +854,7 @@ class Booster(object): ...@@ -686,6 +854,7 @@ class Booster(object):
and you should group grad and hess in this way as well and you should group grad and hess in this way as well
Parameters Parameters
---------- ----------
train_set : training data, None means use last training data
fobj : function fobj : function
Customized objective function. Customized objective function.
...@@ -693,6 +862,15 @@ class Booster(object): ...@@ -693,6 +862,15 @@ class Booster(object):
------- -------
is_finished, bool is_finished, bool
""" """
"""need reset training data"""
if train_set is not None and train_set is not self.train_set:
if train_set.predictor is not self.init_predictor:
raise Exception("Replace training data failed, you should use same predictor for these data")
self.train_set = train_set
_safe_call(_LIB.LGBM_BoosterResetTrainingData(
self.handle,
self.train_set.handle))
self.__inner_predict_buffer[0] = None
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
if fobj is None: if fobj is None:
_safe_call(_LIB.LGBM_BoosterUpdateOneIter( _safe_call(_LIB.LGBM_BoosterUpdateOneIter(
...@@ -701,9 +879,9 @@ class Booster(object): ...@@ -701,9 +879,9 @@ class Booster(object):
return is_finished.value == 1 return is_finished.value == 1
else: else:
grad, hess = fobj(self.__inner_predict(0), self.train_set) grad, hess = fobj(self.__inner_predict(0), self.train_set)
return self.boost(grad, hess) return self.__boost(grad, hess)
def boost(self, grad, hess): def __boost(self, grad, hess):
""" """
Boost the booster for one iteration, with customized gradient statistics. Boost the booster for one iteration, with customized gradient statistics.
Note: for multi-class task, the score is group by class_id first, then group by row_id Note: for multi-class task, the score is group by class_id first, then group by row_id
...@@ -729,11 +907,53 @@ class Booster(object): ...@@ -729,11 +907,53 @@ class Booster(object):
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom( _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
self.handle, self.handle,
grad.ctypes.data_as(ctypes.c_float), grad.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)),
hess.ctypes.data_as(ctypes.c_float), hess.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)),
ctypes.byref(is_finished))) ctypes.byref(is_finished)))
return is_finished.value == 1 return is_finished.value == 1
def rollback_one_iter(self):
    """Undo the most recent boosting iteration of this booster."""
    _safe_call(_LIB.LGBM_BoosterRollbackOneIter(self.handle))
def current_iteration(self):
    """Return the index of the current boosting iteration."""
    cur_iter = ctypes.c_int64(0)
    _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(self.handle,
                                                    ctypes.byref(cur_iter)))
    return cur_iter.value
def eval(self, data, name, feval=None):
    """Evaluate for data

    Parameters
    ----------
    data : Dataset object
    name : string
        Name of the data.
    feval : function
        Custom evaluation function.

    Returns
    -------
    result : str
        Evaluation result string.
    """
    if not isinstance(data, Dataset):
        raise TypeError("Can only eval for Dataset instance")
    # Locate the internal index of this dataset: 0 is the training set,
    # 1..n are the registered validation sets.
    data_idx = -1
    if data is self.train_set:
        data_idx = 0
    else:
        # idiom: enumerate instead of indexing with range(len(...))
        for i, valid_set in enumerate(self.valid_sets):
            if data is valid_set:
                data_idx = i + 1
                break
    """need push new valid data"""
    if data_idx == -1:
        self.add_valid_data(data, name)
        data_idx = self.__num_dataset - 1
    return self.__inner_eval(name, data_idx, feval)
def eval_train(self, feval=None): def eval_train(self, feval=None):
"""Evaluate for training data """Evaluate for training data
...@@ -774,141 +994,28 @@ class Booster(object): ...@@ -774,141 +994,28 @@ class Booster(object):
c_str(filename))) c_str(filename)))
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True): def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
if isinstance(data, Dataset): predictor = Predictor(booster_handle=self.handle, is_manage_handle=False)
raise TypeError("cannot use Dataset instance for prediction, please use raw data instead") return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
predict_type = C_API_PREDICT_NORMAL
if raw_score:
predict_type = cC_API_PREDICT_RAW_SCORE
if pred_leaf:
predict_type = C_API_PREDICT_LEAF_INDEX
int_data_has_header = 0
if data_has_header:
int_data_has_header = 1
if is_str(data):
tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name
_safe_call(_LIB.LGBM_BoosterPredictForFile(
self.handle,
c_str(data),
int_data_has_header,
predict_type,
num_iteration,
c_str(tmp_pred_fname)))
lines = open(tmp_pred_fname,"r").readlines()
nrow = len(lines)
preds = []
for line in lines:
for token in line.split('\t'):
preds.append(float(token))
preds = np.array(preds, copy=False)
os.remove(tmp_pred_fname)
elif isinstance(data, scipy.sparse.csr_matrix):
preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type)
elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
else:
try:
csr = scipy.sparse.csr_matrix(data)
res = self.__pred_for_csr(csr, num_iteration, predict_type)
except:
raise TypeError('can not predict data for type {}'.format(type(data).__name__))
if pred_leaf:
preds = preds.astype(np.int32)
if preds.size != nrow and is_reshape:
if preds.size % nrow == 0:
ncol = int(preds.size / nrow)
preds = preds.reshape(nrow, ncol)
else:
raise ValueError('len of predict result(%d) cannot be divide nrow(%d)' %(preds.size, nrow) )
return preds
def __pred_for_np2d(self, mat, num_iteration, predict_type):
"""
Predict for a 2-D numpy matrix.
"""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
if mat.dtype == np.float32 or mat.dtype == np.float64: def to_predictor(self):
data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) predictor = Predictor(booster_handle=self.handle, is_manage_handle=True)
else: self.__is_manage_handle = False
"""change non-float data to float data, need to copy""" return predictor
data = np.array(mat.reshape(mat.size), dtype=np.float32)
ptr_data, type_ptr_data = c_float_array(data)
n_preds = self.__num_class * mat.shape[0]
if predict_type == C_API_PREDICT_LEAF_INDEX:
if num_iteration > 0:
n_preds *= num_iteration
else:
used_iteration = self.__num_total_model / self.__num_class
n_preds *= used_iteration
preds = np.zeros(n_preds, dtype=np.float32)
out_num_preds = ctypes.c_int64(0)
_safe_call(LIB.LGBM_BoosterPredictForMat(
self.handle,
ptr_data,
type_ptr_data,
mat.shape[0],
mat.shape[1],
C_API_IS_ROW_MAJOR,
predict_type,
num_iteration,
ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
))
if n_preds != out_num_preds.value:
raise ValueError("incorrect number for predict result")
return preds, mat.shape[0]
def __pred_for_csr(self, csr, num_iteration, predict_type):
    """Run prediction for a scipy CSR matrix.

    Parameters
    ----------
    csr : scipy.sparse.csr_matrix
        Input features, one row per sample.
    num_iteration : int
        Number of iterations to use for prediction; <= 0 means all.
    predict_type : int
        One of the C_API_PREDICT_* constants.

    Returns
    -------
    (preds, nrow) : (numpy.ndarray of float32, int)
        Flat prediction array and the number of input rows.

    Raises
    ------
    ValueError
        If the C library returns an unexpected number of predictions.
    """
    nrow = len(csr.indptr) - 1
    n_preds = self.__num_class * nrow
    if predict_type == C_API_PREDICT_LEAF_INDEX:
        if num_iteration > 0:
            n_preds *= num_iteration
        else:
            # bug fix: integer division — true division yields a float
            # and np.zeros(n_preds) would then raise under Python 3
            used_iteration = self.__num_total_model // self.__num_class
            n_preds *= used_iteration
    preds = np.zeros(n_preds, dtype=np.float32)
    out_num_preds = ctypes.c_int64(0)
    ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
    ptr_data, type_ptr_data = c_float_array(csr.data)
    # bug fixes: `_LIB` (was a NameError: `LIB`), and `indices` must be
    # cast to a ctypes POINTER type, not a plain ctypes scalar type
    _safe_call(_LIB.LGBM_BoosterPredictForCSR(
        self.handle,
        ptr_indptr,
        type_ptr_indptr,
        csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        ptr_data,
        type_ptr_data,
        len(csr.indptr),
        len(csr.data),
        csr.shape[1],
        predict_type,
        num_iteration,
        ctypes.byref(out_num_preds),
        preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    ))
    if n_preds != out_num_preds.value:
        raise ValueError("incorrect number for predict result")
    return preds, nrow
def __inner_eval(self, data_name, data_idx, feval=None): def __inner_eval(self, data_name, data_idx, feval=None):
if data_idx >= self.__num_dataset: if data_idx >= self.__num_dataset:
raise ValueError("data_idx should be smaller than number of dataset") raise ValueError("data_idx should be smaller than number of dataset")
self.__get_eval_info()
ret = [] ret = []
if self.__num_inner_eval > 0: if self.__num_inner_eval > 0:
result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float32) result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float32)
out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetEval( _safe_call(_LIB.LGBM_BoosterGetEval(
self.handle, self.handle,
data_idx, data_idx,
ctypes.byref(out_len), ctypes.byref(tmp_out_len),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))) result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))))
if out_len.value != self.__num_inner_eval: if tmp_out_len.value != self.__num_inner_eval:
raise ValueError("incorrect number of eval results") raise ValueError("incorrect number of eval results")
for i in range(self.__num_inner_eval): for i in range(self.__num_inner_eval):
ret.append('%s %s : %f' %(data_name, self.__name_inner_eval[i], result[i])) ret.append('%s %s : %f' %(data_name, self.__name_inner_eval[i], result[i]))
...@@ -936,33 +1043,37 @@ class Booster(object): ...@@ -936,33 +1043,37 @@ class Booster(object):
num_data = self.valid_sets[data_idx - 1].num_data() * self.__num_class num_data = self.valid_sets[data_idx - 1].num_data() * self.__num_class
self.__inner_predict_buffer[data_idx] = \ self.__inner_predict_buffer[data_idx] = \
np.array([0.0 for _ in range(num_data)], dtype=np.float32, copy=False) np.array([0.0 for _ in range(num_data)], dtype=np.float32, copy=False)
out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float)) data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_safe_call(_LIB.LGBM_BoosterGetPredict( _safe_call(_LIB.LGBM_BoosterGetPredict(
self.handle, self.handle,
data_idx, data_idx,
ctypes.byref(out_len), ctypes.byref(tmp_out_len),
data_ptr)) data_ptr))
if out_len.value != len(self.__inner_predict_buffer[data_idx]): if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
raise ValueError("incorrect number of predict results for data %d" %(data_idx) ) raise ValueError("incorrect number of predict results for data %d" %(data_idx) )
return self.__inner_predict_buffer[data_idx] return self.__inner_predict_buffer[data_idx]
def __get_eval_info(self):
def __init_continue_train(self, dataset): if self.__need_reload_eval_info:
if dataset.raw_data is None: self.__need_reload_eval_info = False
raise ValueError("should set is_continue_train=True in dataset while need to continue train") out_num_eval = ctypes.c_int64(0)
init_score = self.predict(dataset.raw_data, raw_score=True,data_has_header=dataset.data_has_header, is_reshape=False) """Get num of inner evals"""
dataset.set_init_score(init_score) _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
dataset.free_raw_data() self.handle,
ctypes.byref(out_num_eval)))
self.__num_inner_eval = out_num_eval.value
#tmp test if self.__num_inner_eval > 0:
train_data = Dataset('../../examples/binary_classification/binary.train') """Get name of evals"""
test_data = Dataset('../../examples/binary_classification/binary.test', reference = train_data) tmp_out_len = ctypes.c_int64(0)
param = {"metric":"l2,l1"} string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)]
lgb = Booster(train_set=train_data, valid_sets=[test_data], params=param) ptr_string_buffers = (ctypes.c_char_p*self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
for i in range(100): _safe_call(_LIB.LGBM_BoosterGetEvalNames(
lgb.update() self.handle,
print(lgb.eval_valid()) ctypes.byref(tmp_out_len),
print(lgb.eval_train()) ptr_string_buffers))
print(lgb.predict('../../examples/binary_classification/binary.train')) if self.__num_inner_eval != tmp_out_len.value:
\ No newline at end of file raise ValueError("size of eval names doesn't equal with num_evals")
self.__name_inner_eval = []
for i in range(self.__num_inner_eval):
self.__name_inner_eval.append(string_buffers[i].value.decode())
...@@ -46,12 +46,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -46,12 +46,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
gbdt_config_ = config; gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round; early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate; shrinkage_rate_ = gbdt_config_->learning_rate;
train_data_ = train_data; random_ = Random(gbdt_config_->bagging_seed);
// create tree learner // create tree learner
tree_learner_.clear(); tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config)); auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_); new_tree_learner->Init(train_data);
// init tree learner // init tree learner
tree_learner_.push_back(std::move(new_tree_learner)); tree_learner_.push_back(std::move(new_tree_learner));
} }
...@@ -63,24 +63,33 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -63,24 +63,33 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
training_metrics_.push_back(metric); training_metrics_.push_back(metric);
} }
training_metrics_.shrink_to_fit(); training_metrics_.shrink_to_fit();
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data_, num_class_));
num_data_ = train_data_->num_data();
// create buffer for gradients and hessians
if (object_function_ != nullptr) {
gradients_ = std::vector<score_t>(num_data_ * num_class_);
hessians_ = std::vector<score_t>(num_data_ * num_class_);
}
sigmoid_ = -1.0f; sigmoid_ = -1.0f;
if (object_function_ != nullptr if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) { && std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform // only binary classification need sigmoid transform
sigmoid_ = gbdt_config_->sigmoid; sigmoid_ = gbdt_config_->sigmoid;
} }
if (train_data_ != train_data) {
// not same training data, need reset score and others
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));
// update score
for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class;
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
}
num_data_ = train_data->num_data();
// create buffer for gradients and hessians
if (object_function_ != nullptr) {
gradients_ = std::vector<score_t>(num_data_ * num_class_);
hessians_ = std::vector<score_t>(num_data_ * num_class_);
}
// get max feature index // get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1; max_feature_idx_ = train_data->num_total_features() - 1;
// get label index // get label index
label_idx_ = train_data_->label_idx(); label_idx_ = train_data->label_idx();
// if need bagging, create buffer // if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_); out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
...@@ -91,14 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -91,14 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
bag_data_cnt_ = num_data_; bag_data_cnt_ = num_data_;
bag_data_indices_.clear(); bag_data_indices_.clear();
} }
random_ = Random(gbdt_config_->bagging_seed);
// update score
for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
} }
train_data_ = train_data;
} }
void GBDT::AddValidDataset(const Dataset* valid_data, void GBDT::AddValidDataset(const Dataset* valid_data,
...@@ -111,7 +114,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data, ...@@ -111,7 +114,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
// update score // update score
for (int i = 0; i < iter_; ++i) { for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class; auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class;
new_score_updater->AddScore(models_[curr_tree].get(), curr_class); new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
} }
} }
...@@ -232,7 +235,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -232,7 +235,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
void GBDT::RollbackOneIter() { void GBDT::RollbackOneIter() {
if (iter_ == 0) { return; } if (iter_ == 0) { return; }
int cur_iter = iter_ - 1; int cur_iter = iter_ + num_init_iteration_ - 1;
// reset score // reset score
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = cur_iter * num_class_ + curr_class; auto curr_tree = cur_iter * num_class_ + curr_class;
......
...@@ -36,12 +36,28 @@ public: ...@@ -36,12 +36,28 @@ public:
const std::vector<const Metric*>& training_metrics) const std::vector<const Metric*>& training_metrics)
override; override;
/*!
* \brief Merge model from other boosting object
Will insert to the front of current boosting object
* \param other
*/
void MergeFrom(const Boosting* other) override { void MergeFrom(const Boosting* other) override {
auto other_gbdt = reinterpret_cast<const GBDT*>(other); auto other_gbdt = reinterpret_cast<const GBDT*>(other);
// tmp move to other vector
auto original_models = std::move(models_);
models_ = std::vector<std::unique_ptr<Tree>>();
// push model from other first
for (const auto& tree : other_gbdt->models_) { for (const auto& tree : other_gbdt->models_) {
auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get()))); auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
num_init_iteration_ = static_cast<int>(models_.size()) / num_class_;
// push model in current object
for (const auto& tree : original_models) {
auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
models_.push_back(std::move(new_tree));
}
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
} }
/*! /*!
...@@ -266,6 +282,7 @@ protected: ...@@ -266,6 +282,7 @@ protected:
int num_iteration_for_pred_; int num_iteration_for_pred_;
/*! \brief Shrinkage rate for one iteration */ /*! \brief Shrinkage rate for one iteration */
double shrinkage_rate_; double shrinkage_rate_;
/*! \brief Number of loaded initial models */
int num_init_iteration_; int num_init_iteration_;
}; };
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
Log::Warning("continued train from model is not support for c_api, \ Log::Warning("continued train from model is not support for c_api, \
please use continued train with input score"); please use continued train with input score");
} }
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, "")); boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
ConstructObjectAndTrainingMetrics(train_data); ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting // initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
...@@ -114,6 +114,10 @@ public: ...@@ -114,6 +114,10 @@ public:
return boosting_->TrainOneIter(gradients, hessians, false); return boosting_->TrainOneIter(gradients, hessians, false);
} }
void RollbackOneIter() {
boosting_->RollbackOneIter();
}
void PrepareForPrediction(int num_iteration, int predict_type) { void PrepareForPrediction(int num_iteration, int predict_type) {
boosting_->SetNumIterationForPred(num_iteration); boosting_->SetNumIterationForPred(num_iteration);
bool is_predict_leaf = false; bool is_predict_leaf = false;
...@@ -156,24 +160,13 @@ public: ...@@ -156,24 +160,13 @@ public:
int idx = 0; int idx = 0;
for (const auto& metric : train_metric_) { for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) { for (const auto& name : metric->GetName()) {
int j = 0; std::strcpy(out_strs[idx], name.c_str());
auto name_cstr = name.c_str();
while (name_cstr[j] != '\0') {
out_strs[idx][j] = name_cstr[j];
++j;
}
out_strs[idx][j] = '\0';
++idx; ++idx;
} }
} }
return idx; return idx;
} }
void RollbackOneIter() {
boosting_->RollbackOneIter();
}
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
private: private:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment