Commit 79de250f authored by Guolin Ke's avatar Guolin Ke
Browse files

add draft for definition of c_api

parent cba6b245
#ifndef LIGHTGBM_C_API_H_
#define LIGHTGBM_C_API_H_
#include<cstdint>
#ifdef __cplusplus
#define DLL_EXTERN_C extern "C"
#else
#define DLL_EXTERN_C
#endif
#ifdef _MSC_VER
#define DllExport DLL_EXTERN_C __declspec(dllexport)
#else
#define DllExport DLL_EXTERN_C
#endif
typedef void* DatesetHandle;
typedef void* BoosterHandle;
/*!
* \brief get string message of the last error
* all function in this file will return 0 when success
* and -1 when an error occured,
* \return const char* error inforomation
*/
DllExport const char* LGBM_GetLastError();
// --- start Dataset inferfaces
/*!
* \brief load data set from file like the command_line LightGBM do
* \param parameters additional parameters:
has_header, label_column, weight_column, group_column, ignore_column
use format like 'has_header=true label_column=1 '..
* \param filename the name of the file
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out a loaded dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief load data set from binary file like the command_line LightGBM do
* \param filename the name of the file
* \param out a loaded dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
DatesetHandle* out);
/*!
* \brief create a dataset from CSR format
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr,
const unsigned* indices,
const float* data,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief create a dataset from CSC format
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr,
const unsigned* indices,
const float* data,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_row,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief create dataset from dense matrix
* \param data pointer to the data space
* \param nrow number of rows
* \param ncol number columns
* \param missing which value to represent missing value
* \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromMat(const float* data,
uint64_t nrow,
uint64_t ncol,
float missing,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief free space for dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetFree(DatesetHandle* handle);
/*!
* \brief save dateset to binary file
* \param handle a instance of dataset
* \param filename file name
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
const char* filename);
/*!
* \brief set vector to a content in info
* \param handle a instance of dataset
* \param field_name field name, can be label, weight, group
* \param field_data pointer to float vector
* \param field_len number of element in field_data
* \param type float:0, int:1
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const char* field_name,
const void* field_data,
uint64_t field_len,
int type);
/*!
* \brief get float info vector from dataset
* \param handle a instance of data matrix
* \param field_name field name
* \param out_len used to set result length
* \param out_ptr pointer to the result
* \param out_type float:0, int:1
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
const char* field_name,
uint64_t* out_len,
const void** out_ptr,
int* out_type);
/*!
* \brief get number of data.
* \param handle the handle to the dataset
* \param out The address to hold number of data
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
uint64_t* out);
/*!
* \brief get number of features
* \param handle the handle to the dataset
* \param out The output of number of features
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
uint64_t* out);
// --- start Booster interfaces
/*!
* \brief create an new boosting learner
* \param train_data traning data set
* \param valid_datas validation data sets
* \param valid_names names of validation data sets
* \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2'
* \prama out handle of created Booster
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterCreate(DatesetHandle train_data,
DatesetHandle valid_datas[],
const char* valid_names[],
int n_valid_datas,
const char* parameters,
BoosterHandle* out);
/*!
* \brief load an exsiting boosting from model file
* \param filename filename of model
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterLoadFromModelfile(
const char* filename,
BoosterHandle* out);
/*!
* \brief free obj in handle
* \param handle handle to be freed
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief update the model in one round
* \param handle handle
* \param is_finished 1 means finised
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
/*!
* \brief update the model, by directly specify gradient and second order gradient,
* this can be used to support customized loss function
* \param handle handle
* \param grad gradient statistics
* \param hess second order gradient statistics
* \param is_finished 1 means finised
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
float* grad,
float* hess,
int* is_finished);
/*!
* \brief get evaluation for training data and validation datas
* \param handle handle
* \param is_eval_train >0 means need to eval trainig data
* \param out_result the string containing evaluation statistics
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterEvalCurrent(BoosterHandle handle,
int is_eval_train,
const char*** out_result);
/*!
* \brief make prediction for training data and validation datas
this can be used to support customized eval function
* \param handle handle
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 2:leaf index
* \param is_predict_train >0 means need to predict for training result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictCurrent(BoosterHandle handle,
int predict_type,
int is_predict_train,
const float*** out_result);
/*!
* \brief make prediction for an new data set
* \param handle handle
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const uint64_t* indptr,
const unsigned* indices,
const float* data,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
int predict_type,
uint64_t n_used_trees,
const float** out_result);
/*!
* \brief make prediction for an new data set
* \param handle handle
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const uint64_t* col_ptr,
const unsigned* indices,
const float* data,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_row,
int predict_type,
uint64_t n_used_trees,
const float** out_result);
/*!
* \brief make prediction for an new data set
* \param handle handle
* \param data pointer to the data space
* \param nrow number of rows
* \param ncol number columns
* \param missing which value to represent missing value
* \param predict_type
* 0:raw score
* 1:with sigmoid transform(if needed)
* 2:leaf index
* \param n_used_trees number of used tree
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
const float* data,
uint64_t nrow,
uint64_t ncol,
float missing,
int predict_type,
uint64_t n_used_trees,
const float** out_result);
/*!
* \brief save model into file
* \param handle handle
* \param is_finished 1 means finised
* \param filename file name
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int is_finished,
const char* filename);
#endif // LIGHTGBM_C_API_H_
#include <LightGBM/c_api.h>
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
namespace LightGBM {
class Booster {
public:
explicit Booster(const char* filename):
boosting_(Boosting::CreateBoosting(filename)) {
}
Booster(const Dataset* train_data,
std::vector<const Dataset*> valid_data,
std::vector<std::string> valid_names,
const char* parameters)
:train_data_(train_data), valid_datas_(valid_data) {
config_.LoadFromString(parameters);
// create boosting
if (config_.io_config.input_model.size() > 0) {
Log::Error("continued train from model is not support for c_api, \
please use continued train with input score");
}
boosting_ = Boosting::CreateBoosting(config_.boosting_type, "");
// create objective function
objective_fun_ =
ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config);
// create training metric
if (config_.boosting_config->is_provide_training_metric) {
for (auto metric_type : config_.metric_types) {
Metric* metric =
Metric::CreateMetric(metric_type, config_.metric_config);
if (metric == nullptr) { continue; }
metric->Init("training", train_data_->metadata(),
train_data_->num_data());
train_metric_.push_back(metric);
}
}
// add metric for validation data
for (size_t i = 0; i < valid_datas_.size(); ++i) {
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
Metric* metric = Metric::CreateMetric(metric_type, config_.metric_config);
if (metric == nullptr) { continue; }
metric->Init(valid_names[i].c_str(),
valid_datas_[i]->metadata(),
valid_datas_[i]->num_data());
valid_metrics_.back().push_back(metric);
}
}
// initialize the objective function
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
// initialize the boosting
boosting_->Init(config_.boosting_config, train_data_, objective_fun_,
ConstPtrInVectorWarpper<Metric>(train_metric_));
// add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i],
ConstPtrInVectorWarpper<Metric>(valid_metrics_[i]));
}
}
~Booster() {
for (auto& metric : train_metric_) {
if (metric != nullptr) { delete metric; }
}
for (auto& metric : valid_metrics_) {
for (auto& sub_metric : metric) {
if (sub_metric != nullptr) { delete sub_metric; }
}
}
valid_metrics_.clear();
if (boosting_ != nullptr) { delete boosting_; }
if (objective_fun_ != nullptr) { delete objective_fun_; }
}
private:
Boosting* boosting_;
/*! \brief All configs */
OverallConfig config_;
/*! \brief Training data */
const Dataset* train_data_;
/*! \brief Validation data */
std::vector<const Dataset*> valid_datas_;
/*! \brief Metric for training data */
std::vector<Metric*> train_metric_;
/*! \brief Metrics for validation data */
std::vector<std::vector<Metric*>> valid_metrics_;
/*! \brief Training objective function */
ObjectiveFunction* objective_fun_;
};
}
...@@ -159,6 +159,7 @@ ...@@ -159,6 +159,7 @@
<ClInclude Include="..\include\LightGBM\bin.h" /> <ClInclude Include="..\include\LightGBM\bin.h" />
<ClInclude Include="..\include\LightGBM\boosting.h" /> <ClInclude Include="..\include\LightGBM\boosting.h" />
<ClInclude Include="..\include\LightGBM\config.h" /> <ClInclude Include="..\include\LightGBM\config.h" />
<ClInclude Include="..\include\LightGBM\c_api.h" />
<ClInclude Include="..\include\LightGBM\dataset.h" /> <ClInclude Include="..\include\LightGBM\dataset.h" />
<ClInclude Include="..\include\LightGBM\feature.h" /> <ClInclude Include="..\include\LightGBM\feature.h" />
<ClInclude Include="..\include\LightGBM\meta.h" /> <ClInclude Include="..\include\LightGBM\meta.h" />
...@@ -203,6 +204,7 @@ ...@@ -203,6 +204,7 @@
<ClCompile Include="..\src\application\application.cpp" /> <ClCompile Include="..\src\application\application.cpp" />
<ClCompile Include="..\src\boosting\boosting.cpp" /> <ClCompile Include="..\src\boosting\boosting.cpp" />
<ClCompile Include="..\src\boosting\gbdt.cpp" /> <ClCompile Include="..\src\boosting\gbdt.cpp" />
<ClCompile Include="..\src\c_api.cpp" />
<ClCompile Include="..\src\io\bin.cpp" /> <ClCompile Include="..\src\io\bin.cpp" />
<ClCompile Include="..\src\io\config.cpp" /> <ClCompile Include="..\src\io\config.cpp" />
<ClCompile Include="..\src\io\dataset.cpp" /> <ClCompile Include="..\src\io\dataset.cpp" />
......
...@@ -165,6 +165,9 @@ ...@@ -165,6 +165,9 @@
<ClInclude Include="..\include\LightGBM\utils\lru_pool.h"> <ClInclude Include="..\include\LightGBM\utils\lru_pool.h">
<Filter>include\LightGBM\utils</Filter> <Filter>include\LightGBM\utils</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\include\LightGBM\c_api.h">
<Filter>include\LightGBM</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\src\application\application.cpp"> <ClCompile Include="..\src\application\application.cpp">
...@@ -230,5 +233,8 @@ ...@@ -230,5 +233,8 @@
<ClCompile Include="..\src\main.cpp"> <ClCompile Include="..\src\main.cpp">
<Filter>src</Filter> <Filter>src</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="..\src\c_api.cpp">
<Filter>src</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
</Project> </Project>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment