Commit bba800df authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python] handle list for prediction (#625), fix #621

* handle list for prediction

* fix error msg
parent bc3d961f
......@@ -414,16 +414,21 @@ class _InnerPredictor(object):
elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d(data, num_iteration,
predict_type)
elif isinstance(data, DataFrame):
preds, nrow = self.__pred_for_np2d(data.values, num_iteration,
predict_type, early_stop_instance_handle)
elif isinstance(data, list):
try:
data = np.array(data)
except:
raise ValueError('Cannot convert data list to numpy array.')
preds, nrow = self.__pred_for_np2d(data, num_iteration,
predict_type)
else:
try:
warnings.warn('Converting data to scipy sparse matrix.')
csr = scipy.sparse.csr_matrix(data)
preds, nrow = self.__pred_for_csr(csr, num_iteration,
predict_type)
except:
raise TypeError('Cannot predict data for type {}'.format(type(data).__name__))
preds, nrow = self.__pred_for_csr(csr, num_iteration,
predict_type)
if pred_leaf:
preds = preds.astype(np.int32)
if is_reshape and preds.size != nrow:
......@@ -452,7 +457,7 @@ class _InnerPredictor(object):
Predict for a 2-D numpy matrix.
"""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
raise ValueError('Input numpy.ndarray or list must be 2 dimensional')
if mat.dtype == np.float32 or mat.dtype == np.float64:
data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment