Commit f059d0fe authored by Guolin Ke's avatar Guolin Ke
Browse files

fix some bugs in basic.py

parent 6e0b58ba
...@@ -78,10 +78,21 @@ def is_numpy_1d_array(data): ...@@ -78,10 +78,21 @@ def is_numpy_1d_array(data):
else: else:
return False return False
def is_1d_list(data):
if not isinstance(data, list):
return False
if len(data) > 0:
if not isinstance(data[0], (int, str, bool) ):
return False
return True
def list_to_1d_numpy(data, dtype): def list_to_1d_numpy(data, dtype):
if is_numpy_1d_array(data): if is_numpy_1d_array(data):
return data if data.dtype == dtype:
elif isinstance(data, list): return data
else:
return data.astype(dtype=dtype, copy=False)
elif is_1d_list(data):
return np.array(data, dtype=dtype, copy=False) return np.array(data, dtype=dtype, copy=False)
else: else:
raise TypeError("Unknow type({})".format(type(data).__name__)) raise TypeError("Unknow type({})".format(type(data).__name__))
...@@ -140,7 +151,7 @@ FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32, ...@@ -140,7 +151,7 @@ FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
def c_float_array(data): def c_float_array(data):
"""Convert numpy array / list to c float array.""" """Convert numpy array / list to c float array."""
if isinstance(data, list): if is_1d_list(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if is_numpy_1d_array(data): if is_numpy_1d_array(data):
if data.dtype == np.float32: if data.dtype == np.float32:
...@@ -157,7 +168,7 @@ def c_float_array(data): ...@@ -157,7 +168,7 @@ def c_float_array(data):
def c_int_array(data): def c_int_array(data):
"""Convert numpy array to c int array.""" """Convert numpy array to c int array."""
if isinstance(data, list): if is_1d_list(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if is_numpy_1d_array(data): if is_numpy_1d_array(data):
if data.dtype == np.int32: if data.dtype == np.int32:
...@@ -256,7 +267,7 @@ class Predictor(object): ...@@ -256,7 +267,7 @@ class Predictor(object):
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
res = self.__pred_for_csr(csr, num_iteration, predict_type) preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type)
except: except:
raise TypeError('can not predict data for type {}'.format(type(data).__name__)) raise TypeError('can not predict data for type {}'.format(type(data).__name__))
if pred_leaf: if pred_leaf:
...@@ -417,7 +428,7 @@ class Dataset(object): ...@@ -417,7 +428,7 @@ class Dataset(object):
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
self.__init_from_csr(csr) self.__init_from_csr(csr, params_str, ref_dataset)
except: except:
raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__)) raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__))
self.__label = None self.__label = None
...@@ -618,8 +629,6 @@ class Dataset(object): ...@@ -618,8 +629,6 @@ class Dataset(object):
The label information to be set into Dataset The label information to be set into Dataset
""" """
label = list_to_1d_numpy(label, np.float32) label = list_to_1d_numpy(label, np.float32)
if label.dtype != np.float32:
label = label.astype(np.float32, copy=False)
self.__label = label self.__label = label
self.set_field('label', label) self.set_field('label', label)
...@@ -633,8 +642,6 @@ class Dataset(object): ...@@ -633,8 +642,6 @@ class Dataset(object):
""" """
if weight is not None: if weight is not None:
weight = list_to_1d_numpy(weight, np.float32) weight = list_to_1d_numpy(weight, np.float32)
if weight.dtype != np.float32:
weight = weight.astype(np.float32, copy=False)
self.__weight = weight self.__weight = weight
self.set_field('weight', weight) self.set_field('weight', weight)
...@@ -647,8 +654,6 @@ class Dataset(object): ...@@ -647,8 +654,6 @@ class Dataset(object):
""" """
if score is not None: if score is not None:
score = list_to_1d_numpy(score, np.float32) score = list_to_1d_numpy(score, np.float32)
if score.dtype != np.float32:
score = score.astype(np.float32, copy=False)
self.__init_score = score self.__init_score = score
self.set_field('init_score', score) self.set_field('init_score', score)
...@@ -662,8 +667,6 @@ class Dataset(object): ...@@ -662,8 +667,6 @@ class Dataset(object):
""" """
if group is not None: if group is not None:
group = list_to_1d_numpy(group, np.int32) group = list_to_1d_numpy(group, np.int32)
if group.dtype != np.int32:
group = group.astype(np.int32, copy=False)
self.__group = group self.__group = group
self.set_field('group', group) self.set_field('group', group)
...@@ -678,8 +681,6 @@ class Dataset(object): ...@@ -678,8 +681,6 @@ class Dataset(object):
""" """
if group_id is not None: if group_id is not None:
group_id = list_to_1d_numpy(group_id, np.int32) group_id = list_to_1d_numpy(group_id, np.int32)
if group_id.dtype != np.int32:
group_id = group_id.astype(np.int32, copy=False)
self.set_field('group_id', group_id) self.set_field('group_id', group_id)
def get_label(self): def get_label(self):
...@@ -890,26 +891,36 @@ class Booster(object): ...@@ -890,26 +891,36 @@ class Booster(object):
and you should group grad and hess in this way as well and you should group grad and hess in this way as well
Parameters Parameters
---------- ----------
grad : 1d numpy with dtype=float32 grad : 1d numpy or list
The first order of gradient. The first order of gradient.
hess : 1d numpy with dtype=float32 hess : 1d numpy or list
The second order of gradient. The second order of gradient.
Returns Returns
------- -------
is_finished, bool is_finished, bool
""" """
if not is_numpy_1d_array(grad) and not is_numpy_1d_array(hess): if not is_numpy_1d_array(grad):
raise TypeError('type of grad / hess should be 1d numpy object') if is_1d_list(grad):
if not grad.dtype == np.float32 and not hess.dtype == np.float32: grad = np.array(grad, dtype=np.float32, copy=False)
raise TypeError('type of grad / hess should be np.float32') else:
raise TypeError("grad should be numpy 1d array or 1d list")
if not is_numpy_1d_array(hess):
if is_1d_list(hess):
hess = np.array(hess, dtype=np.float32, copy=False)
else:
raise TypeError("hess should be numpy 1d array or 1d list")
if len(grad) != len(hess): if len(grad) != len(hess):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if grad.dtype != np.float32:
grad = grad.astype(np.float32, copy=False)
if hess.dtype != np.float32:
hess = hess.astype(np.float32, copy=False)
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom( _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
self.handle, self.handle,
grad.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)), grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
hess.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)), hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
ctypes.byref(is_finished))) ctypes.byref(is_finished)))
return is_finished.value == 1 return is_finished.value == 1
...@@ -950,7 +961,7 @@ class Booster(object): ...@@ -950,7 +961,7 @@ class Booster(object):
break break
"""need push new valid data""" """need push new valid data"""
if data_idx == -1: if data_idx == -1:
self.add_valid_data(data, name) self.add_valid(data, name)
data_idx = self.__num_dataset - 1 data_idx = self.__num_dataset - 1
return self.__inner_eval(name, data_idx, feval) return self.__inner_eval(name, data_idx, feval)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment