Commit bba800df authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python] handle list for prediction (#625), fix #621

* handle list for prediction

* fix error msg
parent bc3d961f
...@@ -414,16 +414,21 @@ class _InnerPredictor(object): ...@@ -414,16 +414,21 @@ class _InnerPredictor(object):
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d(data, num_iteration, preds, nrow = self.__pred_for_np2d(data, num_iteration,
predict_type) predict_type)
elif isinstance(data, DataFrame): elif isinstance(data, list):
preds, nrow = self.__pred_for_np2d(data.values, num_iteration, try:
predict_type, early_stop_instance_handle) data = np.array(data)
except:
raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration,
predict_type)
else: else:
try: try:
warnings.warn('Converting data to scipy sparse matrix.')
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
preds, nrow = self.__pred_for_csr(csr, num_iteration,
predict_type)
except: except:
raise TypeError('Cannot predict data for type {}'.format(type(data).__name__)) raise TypeError('Cannot predict data for type {}'.format(type(data).__name__))
preds, nrow = self.__pred_for_csr(csr, num_iteration,
predict_type)
if pred_leaf: if pred_leaf:
preds = preds.astype(np.int32) preds = preds.astype(np.int32)
if is_reshape and preds.size != nrow: if is_reshape and preds.size != nrow:
...@@ -452,7 +457,7 @@ class _InnerPredictor(object): ...@@ -452,7 +457,7 @@ class _InnerPredictor(object):
Predict for a 2-D numpy matrix. Predict for a 2-D numpy matrix.
""" """
if len(mat.shape) != 2: if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional') raise ValueError('Input numpy.ndarray or list must be 2 dimensional')
if mat.dtype == np.float32 or mat.dtype == np.float64: if mat.dtype == np.float32 or mat.dtype == np.float64:
data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment