"python-package/vscode:/vscode.git/clone" did not exist on "044a39d0ee7b3d3d91414f3ada281f271a90587f"
Commit 9d375069 authored by Guolin Ke's avatar Guolin Ke
Browse files

add LGBM_DatasetCreateFromSampledMat api.

parent 9f69165b
...@@ -48,13 +48,33 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError(); ...@@ -48,13 +48,33 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError();
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out); DatasetHandle* out);
/*!
* \brief create a empty dataset by sampling matrix, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param data pointer to the data space
* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
* \param num_sample_row number of rows
* \param ncol number columns
* \param num_total_row number of total rows
* \param parameters additional parameters
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out);
/*! /*!
* \brief create a empty dataset by sampling csc data, if num_sample_row == num_total_row, will construct this dataset. * \brief create a empty dataset by sampling CSR data, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param indptr pointer to row headers * \param indptr pointer to row headers
* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64 * \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param indices findex * \param indices findex
...@@ -69,16 +89,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -69,16 +89,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t n_sample_elem, int64_t n_sample_elem,
int64_t num_col, int64_t num_col,
int64_t num_total_row, int64_t num_total_row,
const char* parameters, const char* parameters,
DatasetHandle* out); DatasetHandle* out);
/*! /*!
* \brief create a empty dataset by reference Dataset * \brief create a empty dataset by reference Dataset
...@@ -88,8 +108,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr, ...@@ -88,8 +108,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row, int64_t num_total_row,
DatasetHandle* out); DatasetHandle* out);
/*! /*!
* \brief push data to existing dataset, if nrow + start_row == num_total_row, will call dataset->FinishLoad * \brief push data to existing dataset, if nrow + start_row == num_total_row, will call dataset->FinishLoad
...@@ -102,11 +122,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc ...@@ -102,11 +122,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
const void* data, const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int32_t start_row); int32_t start_row);
/*! /*!
* \brief push data to existing dataset, if nrow + start_row == num_total_row, will call dataset->FinishLoad * \brief push data to existing dataset, if nrow + start_row == num_total_row, will call dataset->FinishLoad
...@@ -123,15 +143,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, ...@@ -123,15 +143,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int64_t start_row); int64_t start_row);
/*! /*!
* \brief create a dataset from CSR format * \brief create a dataset from CSR format
...@@ -149,16 +169,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, ...@@ -149,16 +169,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out); DatasetHandle* out);
/*! /*!
* \brief create a dataset from CSC format * \brief create a dataset from CSC format
...@@ -176,16 +196,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -176,16 +196,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t ncol_ptr, int64_t ncol_ptr,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out); DatasetHandle* out);
/*! /*!
* \brief create dataset from dense matrix * \brief create dataset from dense matrix
...@@ -200,13 +220,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -200,13 +220,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out); DatasetHandle* out);
/*! /*!
* \brief Create subset of a data * \brief Create subset of a data
...@@ -263,7 +283,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle); ...@@ -263,7 +283,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle);
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
const char* filename); const char* filename);
/*! /*!
* \brief set vector to a content in info * \brief set vector to a content in info
...@@ -277,10 +297,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, ...@@ -277,10 +297,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name, const char* field_name,
const void* field_data, const void* field_data,
int num_element, int num_element,
int type); int type);
/*! /*!
* \brief get info vector from dataset * \brief get info vector from dataset
...@@ -292,10 +312,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -292,10 +312,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name, const char* field_name,
int* out_len, int* out_len,
const void** out_ptr, const void** out_ptr,
int* out_type); int* out_type);
/*! /*!
* \brief get number of data. * \brief get number of data.
...@@ -304,7 +324,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -304,7 +324,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out); int* out);
/*! /*!
* \brief get number of features * \brief get number of features
...@@ -313,7 +333,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, ...@@ -313,7 +333,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out); int* out);
// --- start Booster interfaces // --- start Booster interfaces
...@@ -325,8 +345,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, ...@@ -325,8 +345,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data, LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
const char* parameters, const char* parameters,
BoosterHandle* out); BoosterHandle* out);
/*! /*!
* \brief load an existing boosting from model file * \brief load an existing boosting from model file
...@@ -366,7 +386,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle); ...@@ -366,7 +386,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle);
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle); BoosterHandle other_handle);
/*! /*!
* \brief Add new validation to booster * \brief Add new validation to booster
...@@ -375,7 +395,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, ...@@ -375,7 +395,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatasetHandle valid_data); const DatasetHandle valid_data);
/*! /*!
* \brief Reset training data for booster * \brief Reset training data for booster
...@@ -384,7 +404,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, ...@@ -384,7 +404,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatasetHandle train_data); const DatasetHandle train_data);
/*! /*!
* \brief Reset config for current booster * \brief Reset config for current booster
...@@ -420,9 +440,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi ...@@ -420,9 +440,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad, const float* grad,
const float* hess, const float* hess,
int* is_finished); int* is_finished);
/*! /*!
* \brief Rollback one iteration * \brief Rollback one iteration
...@@ -479,9 +499,9 @@ Note: 1. you should call LGBM_BoosterGetEvalNames first to get the name of evalu ...@@ -479,9 +499,9 @@ Note: 1. you should call LGBM_BoosterGetEvalNames first to get the name of evalu
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx, int data_idx,
int* out_len, int* out_len,
double* out_results); double* out_results);
/*! /*!
* \brief Get number of predict for inner dataset * \brief Get number of predict for inner dataset
...@@ -493,8 +513,8 @@ Note: should pre-allocate memory for out_result, its length is equal to num_cla ...@@ -493,8 +513,8 @@ Note: should pre-allocate memory for out_result, its length is equal to num_cla
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len); int64_t* out_len);
/*! /*!
* \brief Get prediction for training data and validation data * \brief Get prediction for training data and validation data
...@@ -507,9 +527,9 @@ Note: should pre-allocate memory for out_result, its length is equal to num_cla ...@@ -507,9 +527,9 @@ Note: should pre-allocate memory for out_result, its length is equal to num_cla
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
/*! /*!
* \brief make prediction for file * \brief make prediction for file
...@@ -525,11 +545,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, ...@@ -525,11 +545,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* result_filename); const char* result_filename);
/*! /*!
* \brief Get number of prediction * \brief Get number of prediction
...@@ -544,10 +564,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -544,10 +564,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row, int num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len); int64_t* out_len);
/*! /*!
* \brief make prediction for an new data set * \brief make prediction for an new data set
...@@ -573,18 +593,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, ...@@ -573,18 +593,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
/*! /*!
* \brief make prediction for an new data set * \brief make prediction for an new data set
...@@ -610,18 +630,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -610,18 +630,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const void* col_ptr, const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t ncol_ptr, int64_t ncol_ptr,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
/*! /*!
* \brief make prediction for an new data set * \brief make prediction for an new data set
...@@ -644,15 +664,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -644,15 +664,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data, const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result); double* out_result);
/*! /*!
* \brief save model into file * \brief save model into file
...@@ -662,8 +682,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -662,8 +682,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_iteration, int num_iteration,
const char* filename); const char* filename);
/*! /*!
* \brief save model to string * \brief save model to string
...@@ -675,10 +695,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -675,10 +695,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int* out_len, int* out_len,
char* out_str); char* out_str);
/*! /*!
* \brief dump model to json * \brief dump model to json
...@@ -690,10 +710,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -690,10 +710,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int* out_len, int* out_len,
char* out_str); char* out_str);
/*! /*!
* \brief Get leaf value * \brief Get leaf value
...@@ -704,9 +724,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, ...@@ -704,9 +724,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx, int tree_idx,
int leaf_idx, int leaf_idx,
double* out_val); double* out_val);
/*! /*!
* \brief Set leaf value * \brief Set leaf value
...@@ -717,9 +737,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, ...@@ -717,9 +737,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx, int tree_idx,
int leaf_idx, int leaf_idx,
double val); double val);
#if defined(_MSC_VER) #if defined(_MSC_VER)
// exception handle and error msg // exception handle and error msg
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
} }
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
const char* parameters) { const char* parameters) {
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
config_.Set(param); config_.Set(param);
if (config_.num_threads > 0) { if (config_.num_threads > 0) {
...@@ -52,7 +52,7 @@ public: ...@@ -52,7 +52,7 @@ public:
// initialize the boosting // initialize the boosting
boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(), boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
ResetTrainingData(train_data); ResetTrainingData(train_data);
} }
...@@ -71,7 +71,7 @@ public: ...@@ -71,7 +71,7 @@ public:
train_data_ = train_data; train_data_ = train_data;
// create objective function // create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config)); config_.objective_config));
if (objective_fun_ == nullptr) { if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function"); Log::Warning("Using self-defined objective function");
} }
...@@ -92,7 +92,7 @@ public: ...@@ -92,7 +92,7 @@ public:
train_metric_.shrink_to_fit(); train_metric_.shrink_to_fit();
// reset the boosting // reset the boosting
boosting_->ResetTrainingData(&config_.boosting_config, train_data_, boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
void ResetConfig(const char* parameters) { void ResetConfig(const char* parameters) {
...@@ -116,7 +116,7 @@ public: ...@@ -116,7 +116,7 @@ public:
if (param.count("objective")) { if (param.count("objective")) {
// create objective function // create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config)); config_.objective_config));
if (objective_fun_ == nullptr) { if (objective_fun_ == nullptr) {
Log::Warning("Using self-defined objective function"); Log::Warning("Using self-defined objective function");
} }
...@@ -127,7 +127,7 @@ public: ...@@ -127,7 +127,7 @@ public:
} }
boosting_->ResetTrainingData(&config_.boosting_config, train_data_, boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
...@@ -142,7 +142,7 @@ public: ...@@ -142,7 +142,7 @@ public:
} }
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
boosting_->AddValidDataset(valid_data, boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back())); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
} }
bool TrainOneIter() { bool TrainOneIter() {
...@@ -266,13 +266,13 @@ RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int d ...@@ -266,13 +266,13 @@ RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int d
std::function<std::vector<std::pair<int, double>>(int idx)> std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const void* data, int data_type, int64_t nindptr, int64_t nelem); const void* data, int data_type, int64_t nindptr, int64_t nelem);
// Row iterator of on column for CSC matrix // Row iterator of on column for CSC matrix
class CSC_RowIterator { class CSC_RowIterator {
public: public:
CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices, CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx); const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
~CSC_RowIterator() {} ~CSC_RowIterator() {}
// return value at idx, only can access by ascent order // return value at idx, only can access by ascent order
double Get(int idx); double Get(int idx);
...@@ -293,9 +293,9 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() { ...@@ -293,9 +293,9 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config; IOConfig io_config;
...@@ -305,25 +305,59 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -305,25 +305,59 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
*out = loader.LoadFromFile(filename); *out = loader.LoadFromFile(filename);
} else { } else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename, *out = loader.LoadFromFileAlignWithOtherDataset(filename,
reinterpret_cast<const Dataset*>(reference)); reinterpret_cast<const Dataset*>(reference));
} }
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out) {
if (num_sample_row == num_total_row) {
return LGBM_DatasetCreateFromMat(data, data_type, num_total_row, ncol, 1, parameters, nullptr, out);
} else {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
auto get_row_fun = RowFunctionFromDenseMatric(data, num_sample_row, ncol, data_type, 1);
std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol);
for (int i = 0; i < num_sample_row; ++i) {
auto row = get_row_fun(i);
for (size_t idx = 0; idx < row.size(); ++idx) {
if (std::fabs(row[idx]) > kEpsilon) {
sample_values[idx].emplace_back(row[idx]);
sample_idx[idx].emplace_back(i);
}
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row,
static_cast<data_size_t>(num_total_row));
API_END();
}
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t n_sample_elem, int64_t n_sample_elem,
int64_t num_col, int64_t num_col,
int64_t num_total_row, int64_t num_total_row,
const char* parameters, const char* parameters,
DatasetHandle* out) { DatasetHandle* out) {
if (nindptr - 1 == num_total_row) { if (nindptr - 1 == num_total_row) {
return LGBM_DatasetCreateFromCSR(indptr, indptr_type, indices, data, return LGBM_DatasetCreateFromCSR(indptr, indptr_type, indices, data,
data_type, nindptr, n_sample_elem, num_col, parameters, nullptr, out); data_type, nindptr, n_sample_elem, num_col, parameters, nullptr, out);
} else { } else {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
...@@ -349,15 +383,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr, ...@@ -349,15 +383,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
CHECK(num_col >= static_cast<int>(sample_values.size())); CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx, *out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row, num_sample_row,
static_cast<data_size_t>(num_total_row)); static_cast<data_size_t>(num_total_row));
API_END(); API_END();
} }
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row, int64_t num_total_row,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
std::unique_ptr<Dataset> ret; std::unique_ptr<Dataset> ret;
ret.reset(new Dataset(static_cast<data_size_t>(num_total_row))); ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
...@@ -367,11 +401,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc ...@@ -367,11 +401,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
const void* data, const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int32_t start_row) { int32_t start_row) {
API_BEGIN(); API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset); auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
...@@ -388,15 +422,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, ...@@ -388,15 +422,15 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t, int64_t,
int64_t start_row) { int64_t start_row) {
API_BEGIN(); API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset); auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
...@@ -406,7 +440,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, ...@@ -406,7 +440,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
auto one_row = get_row_fun(i); auto one_row = get_row_fun(i);
p_dataset->PushOneRow(tid, p_dataset->PushOneRow(tid,
static_cast<data_size_t>(start_row + i), one_row); static_cast<data_size_t>(start_row + i), one_row);
} }
if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) { if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
p_dataset->FinishLoad(); p_dataset->FinishLoad();
...@@ -415,13 +449,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, ...@@ -415,13 +449,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config; IOConfig io_config;
...@@ -465,16 +499,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, ...@@ -465,16 +499,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t num_col, int64_t num_col,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config; IOConfig io_config;
...@@ -524,16 +558,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -524,16 +558,16 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t ncol_ptr, int64_t ncol_ptr,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
auto param = ConfigBase::Str2Map(parameters); auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config; IOConfig io_config;
...@@ -547,7 +581,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -547,7 +581,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values(ncol_ptr - 1); std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1); std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i); CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
for (int j = 0; j < sample_cnt; j++) { for (int j = 0; j < sample_cnt; j++) {
...@@ -641,7 +675,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) { ...@@ -641,7 +675,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SaveBinaryFile(filename); dataset->SaveBinaryFile(filename);
...@@ -649,10 +683,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, ...@@ -649,10 +683,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name, const char* field_name,
const void* field_data, const void* field_data,
int num_element, int num_element,
int type) { int type) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false; bool is_success = false;
...@@ -668,10 +702,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -668,10 +702,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name, const char* field_name,
int* out_len, int* out_len,
const void** out_ptr, const void** out_ptr,
int* out_type) { int* out_type) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false; bool is_success = false;
...@@ -691,7 +725,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -691,7 +725,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out) { int* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data(); *out = dataset->num_data();
...@@ -699,7 +733,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, ...@@ -699,7 +733,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out) { int* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features(); *out = dataset->num_total_features();
...@@ -709,8 +743,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, ...@@ -709,8 +743,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
// ---- start of booster // ---- start of booster
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data, LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
const char* parameters, const char* parameters,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data); const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters)); auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
...@@ -748,7 +782,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -748,7 +782,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle) { BoosterHandle other_handle) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Booster* ref_other_booster = reinterpret_cast<Booster*>(other_handle); Booster* ref_other_booster = reinterpret_cast<Booster*>(other_handle);
...@@ -757,7 +791,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, ...@@ -757,7 +791,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatasetHandle valid_data) { const DatasetHandle valid_data) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data); const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
...@@ -766,7 +800,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, ...@@ -766,7 +800,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatasetHandle train_data) { const DatasetHandle train_data) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data); const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
...@@ -800,9 +834,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi ...@@ -800,9 +834,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad, const float* grad,
const float* hess, const float* hess,
int* is_finished) { int* is_finished) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) { if (ref_booster->TrainOneIter(grad, hess)) {
...@@ -856,9 +890,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_l ...@@ -856,9 +890,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_l
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx, int data_idx,
int* out_len, int* out_len,
double* out_results) { double* out_results) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
...@@ -871,8 +905,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle, ...@@ -871,8 +905,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len) { int64_t* out_len) {
API_BEGIN(); API_BEGIN();
auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting(); auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
*out_len = boosting->GetNumPredictAt(data_idx); *out_len = boosting->GetNumPredictAt(data_idx);
...@@ -880,9 +914,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle, ...@@ -880,9 +914,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->GetPredictAt(data_idx, out_result, out_len); ref_booster->GetPredictAt(data_idx, out_result, out_len);
...@@ -890,11 +924,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, ...@@ -890,11 +924,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type); auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
...@@ -917,10 +951,10 @@ int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t n ...@@ -917,10 +951,10 @@ int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t n
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row, int num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len) { int64_t* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = static_cast<int64_t>(num_row * GetNumPredOneRow(ref_booster, predict_type, num_iteration)); *out_len = static_cast<int64_t>(num_row * GetNumPredOneRow(ref_booster, predict_type, num_iteration));
...@@ -928,18 +962,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, ...@@ -928,18 +962,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t, int64_t,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type); auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
...@@ -959,18 +993,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -959,18 +993,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const void* col_ptr, const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
int data_type, int data_type,
int64_t ncol_ptr, int64_t ncol_ptr,
int64_t nelem, int64_t nelem,
int64_t num_row, int64_t num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type); auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
...@@ -978,7 +1012,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -978,7 +1012,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int ncol = static_cast<int>(ncol_ptr - 1); int ncol = static_cast<int>(ncol_ptr - 1);
Threading::For<int64_t>(0, num_row, Threading::For<int64_t>(0, num_row,
[&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem] [&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
(int, data_size_t start, data_size_t end) { (int, data_size_t start, data_size_t end) {
std::vector<CSC_RowIterator> iterators; std::vector<CSC_RowIterator> iterators;
for (int j = 0; j < ncol; ++j) { for (int j = 0; j < ncol; ++j) {
...@@ -1004,15 +1038,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1004,15 +1038,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data, const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
int is_row_major, int is_row_major,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type); auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
...@@ -1031,8 +1065,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1031,8 +1065,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_iteration, int num_iteration,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_iteration, filename); ref_booster->SaveModelToFile(num_iteration, filename);
...@@ -1040,10 +1074,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -1040,10 +1074,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int* out_len, int* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(num_iteration); std::string model = ref_booster->SaveModelToString(num_iteration);
...@@ -1055,10 +1089,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1055,10 +1089,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int* out_len, int* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(num_iteration); std::string model = ref_booster->DumpModel(num_iteration);
...@@ -1070,9 +1104,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, ...@@ -1070,9 +1104,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx, int tree_idx,
int leaf_idx, int leaf_idx,
double* out_val) { double* out_val) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_val = static_cast<double>(ref_booster->GetLeafValue(tree_idx, leaf_idx)); *out_val = static_cast<double>(ref_booster->GetLeafValue(tree_idx, leaf_idx));
...@@ -1080,9 +1114,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, ...@@ -1080,9 +1114,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx, int tree_idx,
int leaf_idx, int leaf_idx,
double val) { double val) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SetLeafValue(tree_idx, leaf_idx, val); ref_booster->SetLeafValue(tree_idx, leaf_idx, val);
...@@ -1096,7 +1130,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1096,7 +1130,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx; auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1105,7 +1139,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1105,7 +1139,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
...@@ -1116,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1116,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx; auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1125,7 +1159,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1125,7 +1159,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
...@@ -1141,7 +1175,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)> ...@@ -1141,7 +1175,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major); auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
if (inner_function != nullptr) { if (inner_function != nullptr) {
return [inner_function](int row_idx) { return [inner_function] (int row_idx) {
auto raw_values = inner_function(row_idx); auto raw_values = inner_function(row_idx);
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) { for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
...@@ -1161,7 +1195,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1161,7 +1195,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1172,7 +1206,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1172,7 +1206,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1186,7 +1220,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1186,7 +1220,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1197,7 +1231,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1197,7 +1231,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1220,7 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1220,7 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1233,7 +1267,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1233,7 +1267,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1249,7 +1283,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1249,7 +1283,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1262,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1262,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1277,7 +1311,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1277,7 +1311,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
} }
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices, CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) { const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
iter_fun_ = IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx); iter_fun_ = IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment