Commit 80495ca6 authored by Guolin Ke's avatar Guolin Ke
Browse files

try fix core_dump error in c_api

parent 5442ed78
...@@ -165,12 +165,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, ...@@ -165,12 +165,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
loader.SetHeader(filename); loader.SetHeader(filename);
if (reference == nullptr) { if (reference == nullptr) {
*out = new std::shared_ptr<Dataset>(loader.LoadFromFile(filename)); *out = loader.LoadFromFile(filename);
} else { } else {
*out = new std::shared_ptr<Dataset>( *out = loader.LoadFromFileAlignWithOtherDataset(filename,
loader.LoadFromFileAlignWithOtherDataset(filename, reinterpret_cast<const Dataset*>(*reference));
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get())
);
} }
API_END(); API_END();
} }
...@@ -180,7 +178,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, ...@@ -180,7 +178,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
API_BEGIN(); API_BEGIN();
OverallConfig config; OverallConfig config;
DatasetLoader loader(config.io_config, nullptr); DatasetLoader loader(config.io_config, nullptr);
*out = new std::shared_ptr<Dataset>(loader.LoadFromBinFile(filename, 0, 1)); *out = loader.LoadFromBinFile(filename, 0, 1);
API_END(); API_END();
} }
...@@ -217,7 +215,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -217,7 +215,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
} else { } else {
ret.reset(new Dataset(nrow, config.io_config.num_class)); ret.reset(new Dataset(nrow, config.io_config.num_class));
ret->CopyFeatureMapperFrom( ret->CopyFeatureMapperFrom(
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(), reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse); config.io_config.is_enable_sparse);
} }
...@@ -228,7 +226,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -228,7 +226,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
ret->PushOneRow(tid, i, one_row); ret->PushOneRow(tid, i, one_row);
} }
ret->FinishLoad(); ret->FinishLoad();
*out = new std::shared_ptr<Dataset>(ret.release()); *out = ret.release();
API_END(); API_END();
} }
...@@ -278,7 +276,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -278,7 +276,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
} else { } else {
ret.reset(new Dataset(nrow, config.io_config.num_class)); ret.reset(new Dataset(nrow, config.io_config.num_class));
ret->CopyFeatureMapperFrom( ret->CopyFeatureMapperFrom(
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(), reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse); config.io_config.is_enable_sparse);
} }
...@@ -289,7 +287,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -289,7 +287,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
ret->PushOneRow(tid, i, one_row); ret->PushOneRow(tid, i, one_row);
} }
ret->FinishLoad(); ret->FinishLoad();
*out = new std::shared_ptr<Dataset>(ret.release()); *out = ret.release();
API_END(); API_END();
} }
...@@ -327,7 +325,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -327,7 +325,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
} else { } else {
ret.reset(new Dataset(nrow, config.io_config.num_class)); ret.reset(new Dataset(nrow, config.io_config.num_class));
ret->CopyFeatureMapperFrom( ret->CopyFeatureMapperFrom(
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(), reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse); config.io_config.is_enable_sparse);
} }
...@@ -338,21 +336,21 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -338,21 +336,21 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
ret->PushOneColumn(tid, i, one_col); ret->PushOneColumn(tid, i, one_col);
} }
ret->FinishLoad(); ret->FinishLoad();
*out = new std::shared_ptr<Dataset>(ret.release()); *out = ret.release();
API_END(); API_END();
} }
DllExport int LGBM_DatasetFree(DatesetHandle handle) { DllExport int LGBM_DatasetFree(DatesetHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<std::shared_ptr<Dataset>*>(handle); delete reinterpret_cast<Dataset*>(handle);
API_END(); API_END();
} }
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->get()->SaveBinaryFile(filename); dataset->SaveBinaryFile(filename);
API_END(); API_END();
} }
...@@ -362,12 +360,12 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, ...@@ -362,12 +360,12 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
int64_t num_element, int64_t num_element,
int type) { int type) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false; bool is_success = false;
if (type == C_API_DTYPE_FLOAT32) { if (type == C_API_DTYPE_FLOAT32) {
is_success = dataset->get()->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element)); is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
} else if (type == C_API_DTYPE_INT32) { } else if (type == C_API_DTYPE_INT32) {
is_success = dataset->get()->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element)); is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
} }
if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); } if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); }
API_END(); API_END();
...@@ -379,12 +377,12 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle, ...@@ -379,12 +377,12 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
const void** out_ptr, const void** out_ptr,
int* out_type) { int* out_type) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false; bool is_success = false;
if (dataset->get()->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) { if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
*out_type = C_API_DTYPE_FLOAT32; *out_type = C_API_DTYPE_FLOAT32;
is_success = true; is_success = true;
} else if (dataset->get()->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) { } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
*out_type = C_API_DTYPE_INT32; *out_type = C_API_DTYPE_INT32;
is_success = true; is_success = true;
} }
...@@ -395,16 +393,16 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle, ...@@ -395,16 +393,16 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle, DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
int64_t* out) { int64_t* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->get()->num_data(); *out = dataset->num_data();
API_END(); API_END();
} }
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
int64_t* out) { int64_t* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->get()->num_total_features(); *out = dataset->num_total_features();
API_END(); API_END();
} }
...@@ -418,14 +416,14 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -418,14 +416,14 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const char* parameters, const char* parameters,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const std::shared_ptr<Dataset>*>(train_data)->get(); const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas; std::vector<const Dataset*> p_valid_datas;
std::vector<std::string> p_valid_names; std::vector<std::string> p_valid_names;
for (int i = 0; i < n_valid_datas; ++i) { for (int i = 0; i < n_valid_datas; ++i) {
p_valid_datas.emplace_back(reinterpret_cast<const std::shared_ptr<Dataset>*>(valid_datas[i])->get()); p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
p_valid_names.emplace_back(valid_names[i]); p_valid_names.emplace_back(valid_names[i]);
} }
*out = new std::shared_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters)); *out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
API_END(); API_END();
} }
...@@ -433,19 +431,19 @@ DllExport int LGBM_BoosterLoadFromModelfile( ...@@ -433,19 +431,19 @@ DllExport int LGBM_BoosterLoadFromModelfile(
const char* filename, const char* filename,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
*out = new std::shared_ptr<Booster>(new Booster(filename)); *out = new Booster(filename);
API_END(); API_END();
} }
DllExport int LGBM_BoosterFree(BoosterHandle handle) { DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<std::shared_ptr<Booster>*>(handle); delete reinterpret_cast<Booster*>(handle);
API_END(); API_END();
} }
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) { DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter()) { if (ref_booster->TrainOneIter()) {
*is_finished = 1; *is_finished = 1;
} else { } else {
...@@ -459,7 +457,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -459,7 +457,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* hess, const float* hess,
int* is_finished) { int* is_finished) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) { if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1; *is_finished = 1;
} else { } else {
...@@ -473,7 +471,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, ...@@ -473,7 +471,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
float* out_results) { float* out_results) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
auto result_buf = boosting->GetEvalAt(data); auto result_buf = boosting->GetEvalAt(data);
*out_len = static_cast<int64_t>(result_buf.size()); *out_len = static_cast<int64_t>(result_buf.size());
...@@ -487,7 +485,7 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle, ...@@ -487,7 +485,7 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
const float** out_result) { const float** out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
int len = 0; int len = 0;
*out_result = ref_booster->GetTrainingScore(&len); *out_result = ref_booster->GetTrainingScore(&len);
*out_len = static_cast<int64_t>(len); *out_len = static_cast<int64_t>(len);
...@@ -499,7 +497,7 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, ...@@ -499,7 +497,7 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
float* out_result) { float* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting(); auto boosting = ref_booster->GetBoosting();
int len = 0; int len = 0;
boosting->GetPredictAt(data, out_result, &len); boosting->GetPredictAt(data, out_result, &len);
...@@ -514,7 +512,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -514,7 +512,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
bool bool_data_has_header = data_has_header > 0 ? true : false; bool bool_data_has_header = data_has_header > 0 ? true : false;
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header); ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
...@@ -534,7 +532,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -534,7 +532,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t n_used_trees, int64_t n_used_trees,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
...@@ -561,7 +559,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -561,7 +559,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int64_t n_used_trees, int64_t n_used_trees,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type); ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
...@@ -581,7 +579,7 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -581,7 +579,7 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model, int num_used_model,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get(); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_used_model, filename); ref_booster->SaveModelToFile(num_used_model, filename);
API_END(); API_END();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment