"src/vscode:/vscode.git/clone" did not exist on "d038aa5716a3e1db0ce717eeef469df366b7aade"
Commit fa4ecfda authored by Guolin Ke's avatar Guolin Ke
Browse files

add constructor for booster

parent de114be5
...@@ -110,6 +110,13 @@ def c_array(ctype, values): ...@@ -110,6 +110,13 @@ def c_array(ctype, values):
"""Convert a python array to c array.""" """Convert a python array to c array."""
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def dict_to_str(data):
    """Convert a dict of parameters to a space-separated 'key=value' string.

    Parameters
    ----------
    data : dict or None
        Parameter mapping. None and the empty dict both yield "" (Booster
        passes params=None by default, so None must be tolerated).

    Returns
    -------
    string
        'key1=val1 key2=val2 ...' in the dict's iteration order.
    """
    if not data:
        return ""
    return ' '.join('{}={}'.format(key, value) for key, value in data.items())
"""marco definition of data type in c_api of LightGBM""" """marco definition of data type in c_api of LightGBM"""
C_API_DTYPE_FLOAT32 =0 C_API_DTYPE_FLOAT32 =0
C_API_DTYPE_FLOAT64 =1 C_API_DTYPE_FLOAT64 =1
...@@ -164,7 +171,7 @@ class Dataset(object): ...@@ -164,7 +171,7 @@ class Dataset(object):
def __init__(self, data, max_bin=255, reference=None, def __init__(self, data, max_bin=255, reference=None,
label=None, weight=None, group_id=None, label=None, weight=None, group_id=None,
silent=False, feature_names=None, silent=False, feature_names=None,
other_args=None): other_params=None, is_continue_train=False):
""" """
Dataset used in LightGBM. Dataset used in LightGBM.
...@@ -187,20 +194,27 @@ class Dataset(object): ...@@ -187,20 +194,27 @@ class Dataset(object):
Whether print messages during construction Whether print messages during construction
feature_names : list, optional feature_names : list, optional
Set names for features. Set names for features.
other_args: list, optional other_params: dict, optional
other parameters, format: ['key1=val1','key2=val2'] other parameters
""" """
if data is None: if data is None:
self.handle = None self.handle = None
return return
"""save raw data for continue train """
if is_continue_train:
self.raw_data = data
else:
self.raw_data = None
"""process for args""" """process for args"""
pass_args = ["max_bin={}".format(max_bin)] params = {}
params["max_bin"] = max_bin
if silent: if silent:
pass_args.append("verbose=0") params["verbose"] = 0
if other_args: if other_params:
pass_args += other_args other_params.update(params)
pass_args_str = ' '.join(pass_args) params = other_params
params_str = dict_to_str(params)
"""process for reference dataset""" """process for reference dataset"""
ref_dataset = None ref_dataset = None
if isinstance(reference, Dataset): if isinstance(reference, Dataset):
...@@ -212,15 +226,15 @@ class Dataset(object): ...@@ -212,15 +226,15 @@ class Dataset(object):
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_CreateDatasetFromFile( _safe_call(_LIB.LGBM_CreateDatasetFromFile(
c_str(data), c_str(data),
c_str(pass_args_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
elif isinstance(data, scipy.sparse.csr_matrix): elif isinstance(data, scipy.sparse.csr_matrix):
self._init_from_csr(data, pass_args_str, ref_dataset) self._init_from_csr(data, params_str, ref_dataset)
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
self._init_from_csc(data, pass_args_str, ref_dataset) self._init_from_csc(data, params_str, ref_dataset)
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, pass_args_str, ref_dataset) self._init_from_npy2d(data, params_str, ref_dataset)
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
...@@ -235,7 +249,10 @@ class Dataset(object): ...@@ -235,7 +249,10 @@ class Dataset(object):
self.set_group_id(group_id) self.set_group_id(group_id)
self.feature_names = feature_names self.feature_names = feature_names
def _init_from_csr(self, csr, pass_args_str, ref_dataset): def free_raw_data(self):
self.raw_data = None
def _init_from_csr(self, csr, params_str, ref_dataset):
""" """
Initialize data from a CSR matrix. Initialize data from a CSR matrix.
""" """
...@@ -255,11 +272,11 @@ class Dataset(object): ...@@ -255,11 +272,11 @@ class Dataset(object):
len(csr.indptr), len(csr.indptr),
len(csr.data), len(csr.data),
csr.shape[1], csr.shape[1],
c_str(pass_args_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
def _init_from_csc(self, csr, pass_args_str, ref_dataset): def _init_from_csc(self, csr, params_str, ref_dataset):
""" """
Initialize data from a CSC matrix. Initialize data from a CSC matrix.
""" """
...@@ -279,11 +296,11 @@ class Dataset(object): ...@@ -279,11 +296,11 @@ class Dataset(object):
len(csc.indptr), len(csc.indptr),
len(csc.data), len(csc.data),
csc.shape[0], csc.shape[0],
c_str(pass_args_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
def _init_from_npy2d(self, mat, pass_args_str, ref_dataset): def _init_from_npy2d(self, mat, params_str, ref_dataset):
""" """
Initialize data from a 2-D numpy matrix. Initialize data from a 2-D numpy matrix.
""" """
...@@ -304,7 +321,7 @@ class Dataset(object): ...@@ -304,7 +321,7 @@ class Dataset(object):
mat.shape[0], mat.shape[0],
mat.shape[1], mat.shape[1],
C_API_IS_ROW_MAJOR, C_API_IS_ROW_MAJOR,
c_str(pass_args_str), c_str(params_str),
ref_dataset, ref_dataset,
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
...@@ -536,3 +553,79 @@ class Dataset(object): ...@@ -536,3 +553,79 @@ class Dataset(object):
else: else:
self._feature_names = None self._feature_names = None
class Booster(object):
    """A Booster of LightGBM."""

    feature_names = None

    def __init__(self, params=None,
                 train_set=None,
                 valid_sets=None,
                 name_valid_sets=None,
                 model_file=None,
                 fobj=None):
        # pylint: disable=invalid-name
        """Initialize the Booster.

        Parameters
        ----------
        params : dict, optional
            Parameters for boosters.
        train_set : Dataset, optional
            Training dataset.
        valid_sets : list of Dataset, optional
            Validation datasets.
        name_valid_sets : list of string, optional
            Names of the validation datasets; defaults to
            "valid_0", "valid_1", ...
        model_file : string, optional
            Path to the model file. With train_set it seeds continued
            training; alone it loads a standalone model.
        fobj : callable, optional
            Customized objective function (accepted but not used in this
            constructor).

        Raises
        ------
        TypeError
            If train_set / valid_sets entries are not Dataset instances,
            or neither train_set nor model_file is given.
        ValueError
            If valid_sets and name_valid_sets differ in length.
        """
        self.handle = ctypes.c_void_p()
        if train_set is not None:
            if not isinstance(train_set, Dataset):
                raise TypeError('training data should be Dataset instance, met {}'.format(type(train_set).__name__))
            valid_handles = None
            valid_cnames = None
            n_valid = 0
            if valid_sets is not None:
                for valid in valid_sets:
                    if not isinstance(valid, Dataset):
                        raise TypeError('valid data should be Dataset instance, met {}'.format(type(valid).__name__))
                valid_handles = c_array(ctypes.c_void_p, [valid.handle for valid in valid_sets])
                if name_valid_sets is None:
                    name_valid_sets = ["valid_{}".format(x) for x in range(len(valid_sets))]
                if len(valid_sets) != len(name_valid_sets):
                    raise ValueError('len of valid_sets should be equal to len of name_valid_sets')
                valid_cnames = c_array(ctypes.c_char_p, [c_str(x) for x in name_valid_sets])
                n_valid = len(valid_sets)
            ref_input_model = None
            # Guard against params=None (the declared default) before
            # serializing; dict_to_str would otherwise see a non-dict.
            params_str = dict_to_str(params) if params else ""
            if model_file is not None:
                ref_input_model = c_str(model_file)
            # Construct the underlying C booster object.
            # fix: was `LIB.` (NameError) — the module-wide handle is `_LIB`;
            # fix: params_str wrapped with c_str(), matching every other
            # parameter-string call site in this file.
            _safe_call(_LIB.LGBM_BoosterCreate(
                train_set.handle,
                valid_handles,
                valid_cnames,
                n_valid,
                c_str(params_str),
                ref_input_model,
                ctypes.byref(self.handle)))
            # If an input model was given, set up continued training on
            # every dataset involved.
            if model_file is not None:
                self.init_continue_train(train_set)
                if valid_sets is not None:
                    for valid in valid_sets:
                        self.init_continue_train(valid)
        elif model_file is not None:
            # No training data: load a standalone model for prediction.
            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
                c_str(model_file),
                ctypes.byref(self.handle)))
        else:
            raise TypeError('At least need training dataset or model file to create booster instance')

    def __del__(self):
        # Free the native booster; guard so a partially constructed object
        # does not raise from the destructor.
        if getattr(self, 'handle', None) is not None:
            _LIB.LGBM_BoosterFree(self.handle)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment