Unverified Commit b5027de3 authored by Alberto Ferreira's avatar Alberto Ferreira Committed by GitHub
Browse files

Fast single row predict API v2 (#3268)

* Fix bug introduced in PR #2992 for Fast predict

* Faster Fast predict API

* Add const to SingleRow Fast methods
parent a9f5654b
...@@ -862,6 +862,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -862,6 +862,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
* Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed. * Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed.
* *
* \param handle Booster handle * \param handle Booster handle
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param num_col Number of columns * \param num_col Number of columns
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
...@@ -869,6 +875,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -869,6 +875,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
* \return 0 when it succeeds, -1 when failure happens * \return 0 when it succeeds, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type, const int data_type,
const int64_t num_col, const int64_t num_col,
const char* parameter, const char* parameter,
...@@ -901,25 +909,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle h ...@@ -901,25 +909,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle h
* \param data Pointer to the data space * \param data Pointer to the data space
* \param nindptr Number of rows in the matrix + 1 * \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix * \param nelem Number of nonzero elements in the matrix
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions * \param[out] out_result Pointer to array with predictions
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
const void* indptr, const void* indptr,
int indptr_type, const int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int64_t nindptr, const int64_t nindptr,
int64_t nelem, const int64_t nelem,
int predict_type,
int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
...@@ -1042,6 +1042,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1042,6 +1042,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
* Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed. * Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed.
* *
* \param handle Booster handle * \param handle Booster handle
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param ncol Number of columns * \param ncol Number of columns
* \param parameter Other parameters for prediction, e.g. early stopping for prediction * \param parameter Other parameters for prediction, e.g. early stopping for prediction
...@@ -1049,8 +1055,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1049,8 +1055,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
* \return 0 when it succeeds, -1 when failure happens * \return 0 when it succeeds, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
int data_type, const int predict_type,
int32_t ncol, const int num_iteration,
const int data_type,
const int32_t ncol,
const char* parameter, const char* parameter,
FastConfigHandle *out_fastConfig); FastConfigHandle *out_fastConfig);
...@@ -1070,20 +1078,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle h ...@@ -1070,20 +1078,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle h
* *
* \param fastConfig_handle FastConfig object handle returned by ``LGBM_BoosterPredictForMatSingleRowFastInit`` * \param fastConfig_handle FastConfig object handle returned by ``LGBM_BoosterPredictForMatSingleRowFastInit``
* \param data Single-row array data (no other way than row-major form). * \param data Single-row array data (no other way than row-major form).
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param[out] out_len Length of output result * \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions * \param[out] out_result Pointer to array with predictions
* \return 0 when it succeeds, -1 when failure happens * \return 0 when it succeeds, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
const void* data, const void* data,
int predict_type,
int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
......
...@@ -1769,13 +1769,15 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle, ...@@ -1769,13 +1769,15 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
struct FastConfig { struct FastConfig {
FastConfig(Booster *const booster_ptr, FastConfig(Booster *const booster_ptr,
const char *parameter, const char *parameter,
const int predict_type_,
const int data_type_, const int data_type_,
const int32_t num_cols) : booster(booster_ptr), data_type(data_type_), ncol(num_cols) { const int32_t num_cols) : booster(booster_ptr), predict_type(predict_type_), data_type(data_type_), ncol(num_cols) {
config.Set(Config::Str2Map(parameter)); config.Set(Config::Str2Map(parameter));
} }
Booster* const booster; Booster* const booster;
Config config; Config config;
const int predict_type;
const int data_type; const int data_type;
const int32_t ncol; const int32_t ncol;
}; };
...@@ -1939,6 +1941,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -1939,6 +1941,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
} }
int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type, const int data_type,
const int64_t num_col, const int64_t num_col,
const char* parameter, const char* parameter,
...@@ -1953,6 +1957,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, ...@@ -1953,6 +1957,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig( auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig(
reinterpret_cast<Booster*>(handle), reinterpret_cast<Booster*>(handle),
parameter, parameter,
predict_type,
data_type, data_type,
static_cast<int32_t>(num_col))); static_cast<int32_t>(num_col)));
...@@ -1960,25 +1965,25 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle, ...@@ -1960,25 +1965,25 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
omp_set_num_threads(fastConfig_ptr->config.num_threads); omp_set_num_threads(fastConfig_ptr->config.num_threads);
} }
fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config);
*out_fastConfig = fastConfig_ptr.release(); *out_fastConfig = fastConfig_ptr.release();
API_END(); API_END();
} }
int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle, int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
const void* indptr, const void* indptr,
int indptr_type, const int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int64_t nindptr, const int64_t nindptr,
int64_t nelem, const int64_t nelem,
int predict_type,
int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle); FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem);
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol, fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config, out_result, out_len); get_row_fun, fastConfig->config, out_result, out_len);
API_END(); API_END();
} }
...@@ -2082,6 +2087,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -2082,6 +2087,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
} }
int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type, const int data_type,
const int32_t ncol, const int32_t ncol,
const char* parameter, const char* parameter,
...@@ -2090,6 +2097,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, ...@@ -2090,6 +2097,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig( auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig(
reinterpret_cast<Booster*>(handle), reinterpret_cast<Booster*>(handle),
parameter, parameter,
predict_type,
data_type, data_type,
ncol)); ncol));
...@@ -2097,21 +2105,21 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle, ...@@ -2097,21 +2105,21 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
omp_set_num_threads(fastConfig_ptr->config.num_threads); omp_set_num_threads(fastConfig_ptr->config.num_threads);
} }
fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config);
*out_fastConfig = fastConfig_ptr.release(); *out_fastConfig = fastConfig_ptr.release();
API_END(); API_END();
} }
int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle, int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
const void* data, const void* data,
const int predict_type,
const int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle); FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
// Single row in row-major format: // Single row in row-major format:
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1);
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol, fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config, get_row_fun, fastConfig->config,
out_result, out_len); out_result, out_len);
API_END(); API_END();
......
...@@ -109,14 +109,11 @@ ...@@ -109,14 +109,11 @@
int LGBM_BoosterPredictForMatSingleRowFastCriticalSWIG(JNIEnv *jenv, int LGBM_BoosterPredictForMatSingleRowFastCriticalSWIG(JNIEnv *jenv,
jdoubleArray data, jdoubleArray data,
FastConfigHandle handle, FastConfigHandle handle,
int predict_type,
int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0); double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
int ret = LGBM_BoosterPredictForMatSingleRowFast(handle, data0, predict_type, int ret = LGBM_BoosterPredictForMatSingleRowFast(handle, data0, out_len, out_result);
num_iteration, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);
...@@ -174,8 +171,6 @@ ...@@ -174,8 +171,6 @@
FastConfigHandle handle, FastConfigHandle handle,
int indptr_type, int indptr_type,
int64_t nelem, int64_t nelem,
int predict_type,
int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
// Alternatives // Alternatives
...@@ -191,7 +186,7 @@ ...@@ -191,7 +186,7 @@
int32_t ind[2] = { 0, numNonZeros }; int32_t ind[2] = { 0, numNonZeros };
int ret = LGBM_BoosterPredictForCSRSingleRowFast(handle, ind, indptr_type, indices0, values0, 2, int ret = LGBM_BoosterPredictForCSRSingleRowFast(handle, ind, indptr_type, indices0, values0, 2,
nelem, predict_type, num_iteration, out_len, out_result); nelem, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT); jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment