Unverified Commit b5027de3 authored by Alberto Ferreira's avatar Alberto Ferreira Committed by GitHub
Browse files

Fast single row predict API v2 (#3268)

* Fix bug introduced in PR #2992 for Fast predict

* Faster Fast predict API

* Add const to SingleRow Fast methods
parent a9f5654b
......@@ -862,6 +862,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
* Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed.
*
* \param handle Booster handle
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param num_col Number of columns
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
......@@ -869,6 +875,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
* \return 0 when it succeeds, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type,
const int64_t num_col,
const char* parameter,
......@@ -901,25 +909,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle h
* \param data Pointer to the data space
* \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
const void* indptr,
int indptr_type,
const int indptr_type,
const int32_t* indices,
const void* data,
int64_t nindptr,
int64_t nelem,
int predict_type,
int num_iteration,
const int64_t nindptr,
const int64_t nelem,
int64_t* out_len,
double* out_result);
......@@ -1042,6 +1042,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
* Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed.
*
* \param handle Booster handle
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param ncol Number of columns
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
......@@ -1049,8 +1055,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
* \return 0 when it succeeds, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
int data_type,
int32_t ncol,
const int predict_type,
const int num_iteration,
const int data_type,
const int32_t ncol,
const char* parameter,
FastConfigHandle *out_fastConfig);
......@@ -1070,20 +1078,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle h
*
* \param fastConfig_handle FastConfig object handle returned by ``LGBM_BoosterPredictForMatSingleRowFastInit``
* \param data Single-row array data (no other way than row-major form).
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions
* \return 0 when it succeeds, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
const void* data,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);
......
......@@ -1769,13 +1769,15 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
struct FastConfig {
FastConfig(Booster *const booster_ptr,
const char *parameter,
const int predict_type_,
const int data_type_,
const int32_t num_cols) : booster(booster_ptr), data_type(data_type_), ncol(num_cols) {
const int32_t num_cols) : booster(booster_ptr), predict_type(predict_type_), data_type(data_type_), ncol(num_cols) {
config.Set(Config::Str2Map(parameter));
}
Booster* const booster;
Config config;
const int predict_type;
const int data_type;
const int32_t ncol;
};
......@@ -1939,6 +1941,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
}
int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type,
const int64_t num_col,
const char* parameter,
......@@ -1953,6 +1957,7 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig(
reinterpret_cast<Booster*>(handle),
parameter,
predict_type,
data_type,
static_cast<int32_t>(num_col)));
......@@ -1960,25 +1965,25 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
omp_set_num_threads(fastConfig_ptr->config.num_threads);
}
fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config);
*out_fastConfig = fastConfig_ptr.release();
API_END();
}
int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
const void* indptr,
int indptr_type,
const int indptr_type,
const int32_t* indices,
const void* data,
int64_t nindptr,
int64_t nelem,
int predict_type,
int num_iteration,
const int64_t nindptr,
const int64_t nelem,
int64_t* out_len,
double* out_result) {
API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem);
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config, out_result, out_len);
API_END();
}
......@@ -2082,6 +2087,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
}
int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type,
const int32_t ncol,
const char* parameter,
......@@ -2090,6 +2097,7 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig(
reinterpret_cast<Booster*>(handle),
parameter,
predict_type,
data_type,
ncol));
......@@ -2097,21 +2105,21 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
omp_set_num_threads(fastConfig_ptr->config.num_threads);
}
fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config);
*out_fastConfig = fastConfig_ptr.release();
API_END();
}
int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
const void* data,
const int predict_type,
const int num_iteration,
int64_t* out_len,
double* out_result) {
API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
// Single row in row-major format:
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1);
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config,
out_result, out_len);
API_END();
......
......@@ -109,14 +109,11 @@
int LGBM_BoosterPredictForMatSingleRowFastCriticalSWIG(JNIEnv *jenv,
jdoubleArray data,
FastConfigHandle handle,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
int ret = LGBM_BoosterPredictForMatSingleRowFast(handle, data0, predict_type,
num_iteration, out_len, out_result);
int ret = LGBM_BoosterPredictForMatSingleRowFast(handle, data0, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);
......@@ -174,8 +171,6 @@
FastConfigHandle handle,
int indptr_type,
int64_t nelem,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
// Alternatives
......@@ -191,7 +186,7 @@
int32_t ind[2] = { 0, numNonZeros };
int ret = LGBM_BoosterPredictForCSRSingleRowFast(handle, ind, indptr_type, indices0, values0, 2,
nelem, predict_type, num_iteration, out_len, out_result);
nelem, out_len, out_result);
jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment