"vscode:/vscode.git/clone" did not exist on "1dbf4382fd7d8cd6ce92b53c3b84aa479aabfc19"
Commit 2a5d7abe authored by Guolin Ke's avatar Guolin Ke
Browse files

simple exception handle and error message. Need RAII in the future.

parent 45cbcb05
#ifndef LIGHTGBM_C_API_H_
#define LIGHTGBM_C_API_H_
#include<cstdint>
#include <cstdint>
#include <exception>
#include <stdexcept>
#include <string>
/*!
* To avoid type conversion on large data, most of our expose interface support both for float_32 and float_64.
* Except following:
......@@ -412,4 +413,30 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi
std::vector<double>
SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices);
// exception handle and error msg
static std::string& LastErrorMsg() { static std::string err_msg("Everything is fine"); return err_msg; }
inline void LGBM_SetLastError(const char* msg) {
LastErrorMsg() = msg;
}
inline int LGBM_APIHandleException(const std::exception& ex) {
LGBM_SetLastError(ex.what());
return -1;
}
inline int LGBM_APIHandleException(const std::string& ex) {
LGBM_SetLastError(ex.c_str());
return -1;
}
#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;
#endif // LIGHTGBM_C_API_H_
......@@ -5,6 +5,8 @@
#include <cstdlib>
#include <cstdarg>
#include <cstring>
#include <exception>
#include <stdexcept>
namespace LightGBM {
......@@ -62,13 +64,15 @@ public:
}
static void Fatal(const char *format, ...) {
va_list val;
char str_buf[1024];
va_start(val, format);
fprintf(stderr, "[LightGBM] [Fatal] ");
vfprintf(stderr, format, val);
fprintf(stderr, "\n");
fflush(stderr);
#ifdef _MSC_VER
vsprintf_s(str_buf, format, val);
#else
vsprintf(str_buf, format, val);
#endif
va_end(val);
exit(1);
throw std::runtime_error(std::string(str_buf));
}
private:
......
......@@ -158,18 +158,15 @@ private:
using namespace LightGBM;
DllExport const char* LGBM_GetLastError() {
return "Not error msg now, will support soon";
return LastErrorMsg().c_str();
}
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
API_BEGIN();
OverallConfig config;
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
......@@ -179,17 +176,16 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
} else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename, reinterpret_cast<const Dataset*>(*reference));
}
return 0;
API_END();
}
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
DatesetHandle* out) {
API_BEGIN();
OverallConfig config;
DatasetLoader loader(config.io_config, nullptr);
*out = loader.LoadFromBinFile(filename, 0, 1);
return 0;
API_END();
}
DllExport int LGBM_CreateDatasetFromMat(const void* data,
......@@ -200,7 +196,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
API_BEGIN();
OverallConfig config;
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
......@@ -235,7 +231,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
}
ret->FinishLoad();
*out = ret;
return 0;
API_END();
}
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
......@@ -249,7 +245,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
API_BEGIN();
OverallConfig config;
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
......@@ -295,10 +291,9 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
ret->FinishLoad();
*out = ret;
return 0;
API_END();
}
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
......@@ -310,6 +305,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
API_BEGIN();
OverallConfig config;
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
......@@ -342,20 +338,22 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
}
ret->FinishLoad();
*out = ret;
return 0;
API_END();
}
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
delete dataset;
return 0;
API_END();
}
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
const char* filename) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SaveBinaryFile(filename);
return 0;
API_END();
}
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
......@@ -363,6 +361,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const void* field_data,
int64_t num_element,
int type) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
if (type == C_API_DTYPE_FLOAT32) {
......@@ -370,8 +369,8 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
} else if (type == C_API_DTYPE_INT32) {
is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
}
if (is_success) { return 0; }
return -1;
if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); }
API_END();
}
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
......@@ -379,29 +378,34 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
int64_t* out_len,
const void** out_ptr,
int* out_type) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
*out_type = C_API_DTYPE_FLOAT32;
return 0;
is_success = true;
} else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
*out_type = C_API_DTYPE_INT32;
return 0;
is_success = true;
}
return -1;
if (!is_success) { throw std::runtime_error("Field not found"); }
API_END();
}
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
int64_t* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data();
return 0;
API_END();
}
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
int64_t* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features();
return 0;
API_END();
}
......@@ -413,6 +417,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
int n_valid_datas,
const char* parameters,
BoosterHandle* out) {
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas;
std::vector<std::string> p_valid_names;
......@@ -421,51 +426,54 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
p_valid_names.emplace_back(valid_names[i]);
}
*out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
return 0;
API_END();
}
DllExport int LGBM_BoosterLoadFromModelfile(
const char* filename,
BoosterHandle* out) {
API_BEGIN();
*out = new Booster(filename);
return 0;
API_END();
}
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
delete ref_booster;
return 0;
API_END();
}
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter()) {
*is_finished = 1;
} else {
*is_finished = 0;
}
return 0;
API_END();
}
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad,
const float* hess,
int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1;
} else {
*is_finished = 0;
}
return 0;
API_END();
}
DllExport int LGBM_BoosterEval(BoosterHandle handle,
int data,
int64_t* out_len,
float* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
auto result_buf = boosting->GetEvalAt(data);
......@@ -473,32 +481,31 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
for (size_t i = 0; i < result_buf.size(); ++i) {
(out_results)[i] = static_cast<float>(result_buf[i]);
}
return 0;
API_END();
}
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int64_t* out_len,
const float** out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
int len = 0;
*out_result = ref_booster->GetTrainingScore(&len);
*out_len = static_cast<int64_t>(len);
return 0;
API_END();
}
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data,
int64_t* out_len,
float* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
int len = 0;
boosting->GetPredictAt(data, out_result, &len);
*out_len = static_cast<int64_t>(len);
return 0;
API_END();
}
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
......@@ -507,12 +514,12 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header,
const char* data_filename,
const char* result_filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
bool bool_data_has_header = data_has_header > 0 ? true : false;
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
return 0;
API_END();
}
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
......@@ -527,7 +534,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int predict_type,
int64_t n_used_trees,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
......@@ -542,7 +549,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
out_result[i * num_class + j] = predicton_result[j];
}
}
return 0;
API_END();
}
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
......@@ -554,7 +561,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int predict_type,
int64_t n_used_trees,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
......@@ -568,16 +575,16 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
out_result[i * num_class + j] = predicton_result[j];
}
}
return 0;
API_END();
}
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model,
const char* filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_used_model, filename);
return 0;
API_END();
}
// ---- start of some help functions
......
#include <iostream>
#include <LightGBM/application.h>
int main(int argc, char** argv) {
try {
LightGBM::Application app(argc, argv);
app.Run();
}
catch (const std::exception& ex) {
std::cerr << "Met Exceptions:" << std::endl;
std::cerr << ex.what() << std::endl;
exit(-1);
}
catch (const std::string& ex) {
std::cerr << "Met Exceptions:" << std::endl;
std::cerr << ex << std::endl;
exit(-1);
}
catch (...) {
std::cerr << "Unknown Exceptions" << std::endl;
exit(-1);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment