Commit fa4ecfda authored by Guolin Ke's avatar Guolin Ke
Browse files

add constructor for booster

parent de114be5
......@@ -110,6 +110,13 @@ def c_array(ctype, values):
"""Convert a python array to c array."""
return (ctype * len(values))(*values)
def dict_to_str(data):
if len(data) == 0:
return ""
pairs = []
for key in data:
pairs.append(str(key)+'='+str(data[key]))
return ' '.join(pairs)
"""marco definition of data type in c_api of LightGBM"""
C_API_DTYPE_FLOAT32 =0
C_API_DTYPE_FLOAT64 =1
......@@ -164,7 +171,7 @@ class Dataset(object):
def __init__(self, data, max_bin=255, reference=None,
label=None, weight=None, group_id=None,
silent=False, feature_names=None,
other_args=None):
other_params=None, is_continue_train=False):
"""
Dataset used in LightGBM.
......@@ -187,20 +194,27 @@ class Dataset(object):
Whether print messages during construction
feature_names : list, optional
Set names for features.
other_args: list, optional
other parameters, format: ['key1=val1','key2=val2']
other_params: dict, optional
other parameters
"""
if data is None:
self.handle = None
return
"""save raw data for continue train """
if is_continue_train:
self.raw_data = data
else:
self.raw_data = None
"""process for args"""
pass_args = ["max_bin={}".format(max_bin)]
params = {}
params["max_bin"] = max_bin
if silent:
pass_args.append("verbose=0")
if other_args:
pass_args += other_args
pass_args_str = ' '.join(pass_args)
params["verbose"] = 0
if other_params:
other_params.update(params)
params = other_params
params_str = dict_to_str(params)
"""process for reference dataset"""
ref_dataset = None
if isinstance(reference, Dataset):
......@@ -212,15 +226,15 @@ class Dataset(object):
self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_CreateDatasetFromFile(
c_str(data),
c_str(pass_args_str),
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
elif isinstance(data, scipy.sparse.csr_matrix):
self._init_from_csr(data, pass_args_str, ref_dataset)
self._init_from_csr(data, params_str, ref_dataset)
elif isinstance(data, scipy.sparse.csc_matrix):
self._init_from_csc(data, pass_args_str, ref_dataset)
self._init_from_csc(data, params_str, ref_dataset)
elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, pass_args_str, ref_dataset)
self._init_from_npy2d(data, params_str, ref_dataset)
else:
try:
csr = scipy.sparse.csr_matrix(data)
......@@ -235,7 +249,10 @@ class Dataset(object):
self.set_group_id(group_id)
self.feature_names = feature_names
def _init_from_csr(self, csr, pass_args_str, ref_dataset):
def free_raw_data(self):
self.raw_data = None
def _init_from_csr(self, csr, params_str, ref_dataset):
"""
Initialize data from a CSR matrix.
"""
......@@ -255,11 +272,11 @@ class Dataset(object):
len(csr.indptr),
len(csr.data),
csr.shape[1],
c_str(pass_args_str),
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
def _init_from_csc(self, csr, pass_args_str, ref_dataset):
def _init_from_csc(self, csr, params_str, ref_dataset):
"""
Initialize data from a CSC matrix.
"""
......@@ -279,11 +296,11 @@ class Dataset(object):
len(csc.indptr),
len(csc.data),
csc.shape[0],
c_str(pass_args_str),
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
def _init_from_npy2d(self, mat, pass_args_str, ref_dataset):
def _init_from_npy2d(self, mat, params_str, ref_dataset):
"""
Initialize data from a 2-D numpy matrix.
"""
......@@ -304,7 +321,7 @@ class Dataset(object):
mat.shape[0],
mat.shape[1],
C_API_IS_ROW_MAJOR,
c_str(pass_args_str),
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
......@@ -536,3 +553,79 @@ class Dataset(object):
else:
self._feature_names = None
class Booster(object):
""""A Booster of of LightGBM.
"""
feature_names = None
def __init__(self, params=None,
train_set=None,
valid_sets=None,
name_valid_sets=None,
model_file=None,
fobj=None):
# pylint: disable=invalid-name
"""Initialize the Booster.
Parameters
----------
params : dict
Parameters for boosters.
train_set : Dataset
training dataset
valid_sets : List of Dataset or None
validation datasets
name_valid_sets : List of string
name of validation datasets
model_file : string
Path to the model file.
"""
self.handle = ctypes.c_void_p()
if train_set is not None:
if not isinstance(train_set, Dataset):
raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__))
valid_handles = None
valid_cnames = None
n_valid = 0
if valid_sets is not None:
for valid in valid_sets:
if not isinstance(valid, Dataset):
raise TypeError('valid data should be Dataset instance, met{}'.format(type(valid).__name__))
valid_handles = c_array(ctypes.c_void_p, [valid.handle for valid in valid_sets])
if name_valid_sets is None:
name_valid_sets = ["valid_{}".format(x) for x in range(len(valid_sets)) ]
if len(valid_sets) != len(name_valid_sets):
raise Exception('len of valid_sets should be equal with len of name_valid_sets')
valid_cnames = c_array(ctypes.c_char_p, [c_str(x) for x in name_valid_sets])
n_valid = len(valid_sets)
ref_input_model = None
params_str = dict_to_str(params)
if model_file is not None:
ref_input_model = c_str(model_file)
"""construct booster object"""
_safe_call(LIB.LGBM_BoosterCreate(
train_set.handle,
valid_handles,
valid_cnames,
n_valid,
params_str,
ref_input_model,
ctypes.byref(self.handle)))
"""if need to continue train"""
if model_file is not None:
self.init_continue_train(train_set)
if valid_sets is not None:
for valid in valid_sets:
self.init_continue_train(valid)
elif model_file is not None:
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(c_str(model_file), ctypes.byref(self.handle)))
else:
raise TypeError('At least need training dataset or model file to create booster instance')
def __del__(self):
_LIB.LGBM_BoosterFree(self.handle)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment