Commit 2a5d7abe authored by Guolin Ke's avatar Guolin Ke
Browse files

simple exception handle and error message. Need RAII in the future.

parent 45cbcb05
#ifndef LIGHTGBM_C_API_H_ #ifndef LIGHTGBM_C_API_H_
#define LIGHTGBM_C_API_H_ #define LIGHTGBM_C_API_H_
#include<cstdint> #include <cstdint>
#include <exception>
#include <stdexcept>
#include <string>
/*! /*!
* To avoid type conversion on large data, most of our expose interface support both for float_32 and float_64. * To avoid type conversion on large data, most of our expose interface support both for float_32 and float_64.
* Except following: * Except following:
...@@ -412,4 +413,30 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi ...@@ -412,4 +413,30 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi
std::vector<double> std::vector<double>
SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices); SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices);
// exception handle and error msg
static std::string& LastErrorMsg() { static std::string err_msg("Everything is fine"); return err_msg; }
inline void LGBM_SetLastError(const char* msg) {
LastErrorMsg() = msg;
}
inline int LGBM_APIHandleException(const std::exception& ex) {
LGBM_SetLastError(ex.what());
return -1;
}
inline int LGBM_APIHandleException(const std::string& ex) {
LGBM_SetLastError(ex.c_str());
return -1;
}
#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;
#endif // LIGHTGBM_C_API_H_ #endif // LIGHTGBM_C_API_H_
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <cstdlib> #include <cstdlib>
#include <cstdarg> #include <cstdarg>
#include <cstring> #include <cstring>
#include <exception>
#include <stdexcept>
namespace LightGBM { namespace LightGBM {
...@@ -62,13 +64,15 @@ public: ...@@ -62,13 +64,15 @@ public:
} }
static void Fatal(const char *format, ...) { static void Fatal(const char *format, ...) {
va_list val; va_list val;
char str_buf[1024];
va_start(val, format); va_start(val, format);
fprintf(stderr, "[LightGBM] [Fatal] "); #ifdef _MSC_VER
vfprintf(stderr, format, val); vsprintf_s(str_buf, format, val);
fprintf(stderr, "\n"); #else
fflush(stderr); vsprintf(str_buf, format, val);
#endif
va_end(val); va_end(val);
exit(1); throw std::runtime_error(std::string(str_buf));
} }
private: private:
......
...@@ -158,18 +158,15 @@ private: ...@@ -158,18 +158,15 @@ private:
using namespace LightGBM; using namespace LightGBM;
DllExport const char* LGBM_GetLastError() { DllExport const char* LGBM_GetLastError() {
return "Not error msg now, will support soon"; return LastErrorMsg().c_str();
} }
DllExport int LGBM_CreateDatasetFromFile(const char* filename, DllExport int LGBM_CreateDatasetFromFile(const char* filename,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN();
OverallConfig config; OverallConfig config;
config.LoadFromString(parameters); config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
...@@ -179,17 +176,16 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, ...@@ -179,17 +176,16 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
} else { } else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename, reinterpret_cast<const Dataset*>(*reference)); *out = loader.LoadFromFileAlignWithOtherDataset(filename, reinterpret_cast<const Dataset*>(*reference));
} }
return 0; API_END();
} }
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN();
OverallConfig config; OverallConfig config;
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
*out = loader.LoadFromBinFile(filename, 0, 1); *out = loader.LoadFromBinFile(filename, 0, 1);
return 0; API_END();
} }
DllExport int LGBM_CreateDatasetFromMat(const void* data, DllExport int LGBM_CreateDatasetFromMat(const void* data,
...@@ -200,7 +196,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -200,7 +196,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN();
OverallConfig config; OverallConfig config;
config.LoadFromString(parameters); config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
...@@ -235,7 +231,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -235,7 +231,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
} }
ret->FinishLoad(); ret->FinishLoad();
*out = ret; *out = ret;
return 0; API_END();
} }
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
...@@ -249,7 +245,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -249,7 +245,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN();
OverallConfig config; OverallConfig config;
config.LoadFromString(parameters); config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
...@@ -295,10 +291,9 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -295,10 +291,9 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
ret->FinishLoad(); ret->FinishLoad();
*out = ret; *out = ret;
return 0; API_END();
} }
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
...@@ -310,6 +305,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -310,6 +305,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
API_BEGIN();
OverallConfig config; OverallConfig config;
config.LoadFromString(parameters); config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
...@@ -342,20 +338,22 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -342,20 +338,22 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
} }
ret->FinishLoad(); ret->FinishLoad();
*out = ret; *out = ret;
return 0; API_END();
} }
DllExport int LGBM_DatasetFree(DatesetHandle handle) { DllExport int LGBM_DatasetFree(DatesetHandle handle) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
delete dataset; delete dataset;
return 0; API_END();
} }
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
const char* filename) { const char* filename) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SaveBinaryFile(filename); dataset->SaveBinaryFile(filename);
return 0; API_END();
} }
DllExport int LGBM_DatasetSetField(DatesetHandle handle, DllExport int LGBM_DatasetSetField(DatesetHandle handle,
...@@ -363,6 +361,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, ...@@ -363,6 +361,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const void* field_data, const void* field_data,
int64_t num_element, int64_t num_element,
int type) { int type) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false; bool is_success = false;
if (type == C_API_DTYPE_FLOAT32) { if (type == C_API_DTYPE_FLOAT32) {
...@@ -370,8 +369,8 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, ...@@ -370,8 +369,8 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
} else if (type == C_API_DTYPE_INT32) { } else if (type == C_API_DTYPE_INT32) {
is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element)); is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
} }
if (is_success) { return 0; } if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); }
return -1; API_END();
} }
DllExport int LGBM_DatasetGetField(DatesetHandle handle, DllExport int LGBM_DatasetGetField(DatesetHandle handle,
...@@ -379,29 +378,34 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle, ...@@ -379,29 +378,34 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
int64_t* out_len, int64_t* out_len,
const void** out_ptr, const void** out_ptr,
int* out_type) { int* out_type) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) { if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
*out_type = C_API_DTYPE_FLOAT32; *out_type = C_API_DTYPE_FLOAT32;
return 0; is_success = true;
} else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) { } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
*out_type = C_API_DTYPE_INT32; *out_type = C_API_DTYPE_INT32;
return 0; is_success = true;
} }
return -1; if (!is_success) { throw std::runtime_error("Field not found"); }
API_END();
} }
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle, DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
int64_t* out) { int64_t* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data(); *out = dataset->num_data();
return 0; API_END();
} }
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
int64_t* out) { int64_t* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features(); *out = dataset->num_total_features();
return 0; API_END();
} }
...@@ -413,6 +417,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -413,6 +417,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
int n_valid_datas, int n_valid_datas,
const char* parameters, const char* parameters,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data); const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas; std::vector<const Dataset*> p_valid_datas;
std::vector<std::string> p_valid_names; std::vector<std::string> p_valid_names;
...@@ -421,51 +426,54 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -421,51 +426,54 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
p_valid_names.emplace_back(valid_names[i]); p_valid_names.emplace_back(valid_names[i]);
} }
*out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters); *out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
return 0; API_END();
} }
DllExport int LGBM_BoosterLoadFromModelfile( DllExport int LGBM_BoosterLoadFromModelfile(
const char* filename, const char* filename,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN();
*out = new Booster(filename); *out = new Booster(filename);
return 0; API_END();
} }
DllExport int LGBM_BoosterFree(BoosterHandle handle) { DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
delete ref_booster; delete ref_booster;
return 0; API_END();
} }
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) { DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter()) { if (ref_booster->TrainOneIter()) {
*is_finished = 1; *is_finished = 1;
} else { } else {
*is_finished = 0; *is_finished = 0;
} }
return 0; API_END();
} }
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad, const float* grad,
const float* hess, const float* hess,
int* is_finished) { int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) { if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1; *is_finished = 1;
} else { } else {
*is_finished = 0; *is_finished = 0;
} }
return 0; API_END();
} }
DllExport int LGBM_BoosterEval(BoosterHandle handle, DllExport int LGBM_BoosterEval(BoosterHandle handle,
int data, int data,
int64_t* out_len, int64_t* out_len,
float* out_results) { float* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
auto result_buf = boosting->GetEvalAt(data); auto result_buf = boosting->GetEvalAt(data);
...@@ -473,32 +481,31 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, ...@@ -473,32 +481,31 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
for (size_t i = 0; i < result_buf.size(); ++i) { for (size_t i = 0; i < result_buf.size(); ++i) {
(out_results)[i] = static_cast<float>(result_buf[i]); (out_results)[i] = static_cast<float>(result_buf[i]);
} }
return 0; API_END();
} }
DllExport int LGBM_BoosterGetScore(BoosterHandle handle, DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
const float** out_result) { const float** out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
int len = 0; int len = 0;
*out_result = ref_booster->GetTrainingScore(&len); *out_result = ref_booster->GetTrainingScore(&len);
*out_len = static_cast<int64_t>(len); *out_len = static_cast<int64_t>(len);
API_END();
return 0;
} }
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data, int data,
int64_t* out_len, int64_t* out_len,
float* out_result) { float* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
int len = 0; int len = 0;
boosting->GetPredictAt(data, out_result, &len); boosting->GetPredictAt(data, out_result, &len);
*out_len = static_cast<int64_t>(len); *out_len = static_cast<int64_t>(len);
return 0; API_END();
} }
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
...@@ -507,12 +514,12 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -507,12 +514,12 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header, int data_has_header,
const char* data_filename, const char* data_filename,
const char* result_filename) { const char* result_filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header); ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
return 0; API_END();
} }
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
...@@ -527,7 +534,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -527,7 +534,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int predict_type, int predict_type,
int64_t n_used_trees, int64_t n_used_trees,
double* out_result) { double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
...@@ -542,7 +549,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -542,7 +549,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
out_result[i * num_class + j] = predicton_result[j]; out_result[i * num_class + j] = predicton_result[j];
} }
} }
return 0; API_END();
} }
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
...@@ -554,7 +561,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -554,7 +561,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int predict_type, int predict_type,
int64_t n_used_trees, int64_t n_used_trees,
double* out_result) { double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
...@@ -568,16 +575,16 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -568,16 +575,16 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
out_result[i * num_class + j] = predicton_result[j]; out_result[i * num_class + j] = predicton_result[j];
} }
} }
return 0; API_END();
} }
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model, int num_used_model,
const char* filename) { const char* filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_used_model, filename); ref_booster->SaveModelToFile(num_used_model, filename);
return 0; API_END();
} }
// ---- start of some help functions // ---- start of some help functions
......
#include <iostream>
#include <LightGBM/application.h> #include <LightGBM/application.h>
int main(int argc, char** argv) { int main(int argc, char** argv) {
LightGBM::Application app(argc, argv); try {
app.Run(); LightGBM::Application app(argc, argv);
app.Run();
}
catch (const std::exception& ex) {
std::cerr << "Met Exceptions:" << std::endl;
std::cerr << ex.what() << std::endl;
exit(-1);
}
catch (const std::string& ex) {
std::cerr << "Met Exceptions:" << std::endl;
std::cerr << ex << std::endl;
exit(-1);
}
catch (...) {
std::cerr << "Unknown Exceptions" << std::endl;
exit(-1);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment