Commit f059d0fe authored by Guolin Ke
Browse files

fix some bugs in basic.py

parent 6e0b58ba
......@@ -78,10 +78,21 @@ def is_numpy_1d_array(data):
else:
return False
def is_1d_list(data):
    """Check whether ``data`` is a flat (non-nested) Python list.

    Only the first element is inspected, so a list whose later elements
    are nested or of another type is still reported as 1-D.

    Parameters
    ----------
    data : object
        Candidate object to test.

    Returns
    -------
    bool
        True if ``data`` is a ``list`` that is either empty or whose first
        element is an ``int``, ``float``, ``str`` or ``bool``; False otherwise.
    """
    if not isinstance(data, list):
        return False
    # An empty list is accepted as 1-D. ``float`` must be allowed here:
    # this check gates list -> numpy conversion of values such as labels
    # and gradients, which are normally plain Python floats.
    return not data or isinstance(data[0], (int, float, str, bool))
def list_to_1d_numpy(data, dtype):
    """Convert a 1-D numpy array or a flat list to a numpy array of ``dtype``.

    Parameters
    ----------
    data : numpy array or list
        Input accepted by ``is_numpy_1d_array`` or ``is_1d_list``.
    dtype : numpy dtype
        Element type of the returned array.

    Returns
    -------
    numpy array
        ``data`` itself when it is already a numpy array of ``dtype``,
        otherwise an array holding the same values converted to ``dtype``.

    Raises
    ------
    TypeError
        If ``data`` is accepted by neither ``is_numpy_1d_array`` nor
        ``is_1d_list``.
    """
    if is_numpy_1d_array(data):
        if data.dtype == dtype:
            return data
        # ``astype(copy=False)`` copies only when a conversion is needed.
        return data.astype(dtype=dtype, copy=False)
    if is_1d_list(data):
        # ``np.asarray`` rather than ``np.array(..., copy=False)``: building
        # an array from a list always copies, and NumPy 2.x raises
        # ValueError when ``copy=False`` cannot be honoured.
        return np.asarray(data, dtype=dtype)
    raise TypeError("Unknow type({})".format(type(data).__name__))
......@@ -140,7 +151,7 @@ FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
def c_float_array(data):
"""Convert numpy array / list to c float array."""
if isinstance(data, list):
if is_1d_list(data):
data = np.array(data, copy=False)
if is_numpy_1d_array(data):
if data.dtype == np.float32:
......@@ -157,7 +168,7 @@ def c_float_array(data):
def c_int_array(data):
"""Convert numpy array to c int array."""
if isinstance(data, list):
if is_1d_list(data):
data = np.array(data, copy=False)
if is_numpy_1d_array(data):
if data.dtype == np.int32:
......@@ -256,7 +267,7 @@ class Predictor(object):
else:
try:
csr = scipy.sparse.csr_matrix(data)
res = self.__pred_for_csr(csr, num_iteration, predict_type)
preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type)
except:
raise TypeError('can not predict data for type {}'.format(type(data).__name__))
if pred_leaf:
......@@ -417,7 +428,7 @@ class Dataset(object):
else:
try:
csr = scipy.sparse.csr_matrix(data)
self.__init_from_csr(csr)
self.__init_from_csr(csr, params_str, ref_dataset)
except:
raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__))
self.__label = None
......@@ -618,8 +629,6 @@ class Dataset(object):
The label information to be set into Dataset
"""
label = list_to_1d_numpy(label, np.float32)
if label.dtype != np.float32:
label = label.astype(np.float32, copy=False)
self.__label = label
self.set_field('label', label)
......@@ -633,8 +642,6 @@ class Dataset(object):
"""
if weight is not None:
weight = list_to_1d_numpy(weight, np.float32)
if weight.dtype != np.float32:
weight = weight.astype(np.float32, copy=False)
self.__weight = weight
self.set_field('weight', weight)
......@@ -647,8 +654,6 @@ class Dataset(object):
"""
if score is not None:
score = list_to_1d_numpy(score, np.float32)
if score.dtype != np.float32:
score = score.astype(np.float32, copy=False)
self.__init_score = score
self.set_field('init_score', score)
......@@ -662,8 +667,6 @@ class Dataset(object):
"""
if group is not None:
group = list_to_1d_numpy(group, np.int32)
if group.dtype != np.int32:
group = group.astype(np.int32, copy=False)
self.__group = group
self.set_field('group', group)
......@@ -678,8 +681,6 @@ class Dataset(object):
"""
if group_id is not None:
group_id = list_to_1d_numpy(group_id, np.int32)
if group_id.dtype != np.int32:
group_id = group_id.astype(np.int32, copy=False)
self.set_field('group_id', group_id)
def get_label(self):
......@@ -890,26 +891,36 @@ class Booster(object):
and you should group grad and hess in this way as well
Parameters
----------
grad : 1d numpy with dtype=float32
grad : 1d numpy or list
The first order of gradient.
hess : 1d numpy with dtype=float32
hess : 1d numpy or list
The second order of gradient.
Returns
-------
is_finished, bool
"""
if not is_numpy_1d_array(grad) and not is_numpy_1d_array(hess):
raise TypeError('type of grad / hess should be 1d numpy object')
if not grad.dtype == np.float32 and not hess.dtype == np.float32:
raise TypeError('type of grad / hess should be np.float32')
if not is_numpy_1d_array(grad):
if is_1d_list(grad):
grad = np.array(grad, dtype=np.float32, copy=False)
else:
raise TypeError("grad should be numpy 1d array or 1d list")
if not is_numpy_1d_array(hess):
if is_1d_list(hess):
hess = np.array(hess, dtype=np.float32, copy=False)
else:
raise TypeError("hess should be numpy 1d array or 1d list")
if len(grad) != len(hess):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if grad.dtype != np.float32:
grad = grad.astype(np.float32, copy=False)
if hess.dtype != np.float32:
hess = hess.astype(np.float32, copy=False)
is_finished = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
self.handle,
grad.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)),
hess.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)),
grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
ctypes.byref(is_finished)))
return is_finished.value == 1
......@@ -950,7 +961,7 @@ class Booster(object):
break
"""need push new valid data"""
if data_idx == -1:
self.add_valid_data(data, name)
self.add_valid(data, name)
data_idx = self.__num_dataset - 1
return self.__inner_eval(name, data_idx, feval)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment