Unverified Commit e1a02ab3 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix indices type in csr and csc (#1719)

* fix indices type in csr and csc

* fix
parent 87c39dec
...@@ -548,6 +548,9 @@ class _InnerPredictor(object): ...@@ -548,6 +548,9 @@ class _InnerPredictor(object):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr) ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csr.data) ptr_data, type_ptr_data, _ = c_float_array(csr.data)
assert csr.shape[1] <= MAX_INT32
csr.indices = csr.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_BoosterPredictForCSR( _safe_call(_LIB.LGBM_BoosterPredictForCSR(
self.handle, self.handle,
ptr_indptr, ptr_indptr,
...@@ -596,6 +599,9 @@ class _InnerPredictor(object): ...@@ -596,6 +599,9 @@ class _InnerPredictor(object):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr) ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csc.data) ptr_data, type_ptr_data, _ = c_float_array(csc.data)
assert csc.shape[0] <= MAX_INT32
csc.indices = csc.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_BoosterPredictForCSC( _safe_call(_LIB.LGBM_BoosterPredictForCSC(
self.handle, self.handle,
ptr_indptr, ptr_indptr,
...@@ -888,6 +894,9 @@ class Dataset(object): ...@@ -888,6 +894,9 @@ class Dataset(object):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr) ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csr.data) ptr_data, type_ptr_data, _ = c_float_array(csr.data)
assert csr.shape[1] <= MAX_INT32
csr.indices = csr.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_DatasetCreateFromCSR( _safe_call(_LIB.LGBM_DatasetCreateFromCSR(
ptr_indptr, ptr_indptr,
ctypes.c_int(type_ptr_indptr), ctypes.c_int(type_ptr_indptr),
...@@ -913,6 +922,9 @@ class Dataset(object): ...@@ -913,6 +922,9 @@ class Dataset(object):
ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr) ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = c_float_array(csc.data) ptr_data, type_ptr_data, _ = c_float_array(csc.data)
assert csc.shape[0] <= MAX_INT32
csc.indices = csc.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_DatasetCreateFromCSC( _safe_call(_LIB.LGBM_DatasetCreateFromCSC(
ptr_indptr, ptr_indptr,
ctypes.c_int(type_ptr_indptr), ctypes.c_int(type_ptr_indptr),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment