"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b37065dbd5510079aac341a34895f4854c782fcb"
Commit 80495ca6 authored by Guolin Ke's avatar Guolin Ke
Browse files

try fix core_dump error in c_api

parent 5442ed78
......@@ -165,12 +165,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
DatasetLoader loader(config.io_config, nullptr);
loader.SetHeader(filename);
if (reference == nullptr) {
*out = new std::shared_ptr<Dataset>(loader.LoadFromFile(filename));
*out = loader.LoadFromFile(filename);
} else {
*out = new std::shared_ptr<Dataset>(
loader.LoadFromFileAlignWithOtherDataset(filename,
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get())
);
*out = loader.LoadFromFileAlignWithOtherDataset(filename,
reinterpret_cast<const Dataset*>(*reference));
}
API_END();
}
......@@ -180,7 +178,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
API_BEGIN();
OverallConfig config;
DatasetLoader loader(config.io_config, nullptr);
*out = new std::shared_ptr<Dataset>(loader.LoadFromBinFile(filename, 0, 1));
*out = loader.LoadFromBinFile(filename, 0, 1);
API_END();
}
......@@ -217,7 +215,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
} else {
ret.reset(new Dataset(nrow, config.io_config.num_class));
ret->CopyFeatureMapperFrom(
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(),
reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse);
}
......@@ -228,7 +226,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
ret->PushOneRow(tid, i, one_row);
}
ret->FinishLoad();
*out = new std::shared_ptr<Dataset>(ret.release());
*out = ret.release();
API_END();
}
......@@ -278,7 +276,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
} else {
ret.reset(new Dataset(nrow, config.io_config.num_class));
ret->CopyFeatureMapperFrom(
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(),
reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse);
}
......@@ -289,7 +287,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
ret->PushOneRow(tid, i, one_row);
}
ret->FinishLoad();
*out = new std::shared_ptr<Dataset>(ret.release());
*out = ret.release();
API_END();
}
......@@ -327,7 +325,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
} else {
ret.reset(new Dataset(nrow, config.io_config.num_class));
ret->CopyFeatureMapperFrom(
reinterpret_cast<const std::shared_ptr<Dataset>*>(*reference)->get(),
reinterpret_cast<const Dataset*>(*reference),
config.io_config.is_enable_sparse);
}
......@@ -338,21 +336,21 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
ret->PushOneColumn(tid, i, one_col);
}
ret->FinishLoad();
*out = new std::shared_ptr<Dataset>(ret.release());
*out = ret.release();
API_END();
}
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
API_BEGIN();
delete reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
delete reinterpret_cast<Dataset*>(handle);
API_END();
}
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
const char* filename) {
API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
dataset->get()->SaveBinaryFile(filename);
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SaveBinaryFile(filename);
API_END();
}
......@@ -362,12 +360,12 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
int64_t num_element,
int type) {
API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
if (type == C_API_DTYPE_FLOAT32) {
is_success = dataset->get()->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
} else if (type == C_API_DTYPE_INT32) {
is_success = dataset->get()->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
}
if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); }
API_END();
......@@ -379,12 +377,12 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
const void** out_ptr,
int* out_type) {
API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false;
if (dataset->get()->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
*out_type = C_API_DTYPE_FLOAT32;
is_success = true;
} else if (dataset->get()->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
} else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
*out_type = C_API_DTYPE_INT32;
is_success = true;
}
......@@ -395,16 +393,16 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
int64_t* out) {
API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
*out = dataset->get()->num_data();
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data();
API_END();
}
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
int64_t* out) {
API_BEGIN();
auto dataset = reinterpret_cast<std::shared_ptr<Dataset>*>(handle);
*out = dataset->get()->num_total_features();
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features();
API_END();
}
......@@ -418,14 +416,14 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const char* parameters,
BoosterHandle* out) {
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const std::shared_ptr<Dataset>*>(train_data)->get();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas;
std::vector<std::string> p_valid_names;
for (int i = 0; i < n_valid_datas; ++i) {
p_valid_datas.emplace_back(reinterpret_cast<const std::shared_ptr<Dataset>*>(valid_datas[i])->get());
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
p_valid_names.emplace_back(valid_names[i]);
}
*out = new std::shared_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters));
*out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
API_END();
}
......@@ -433,19 +431,19 @@ DllExport int LGBM_BoosterLoadFromModelfile(
const char* filename,
BoosterHandle* out) {
API_BEGIN();
*out = new std::shared_ptr<Booster>(new Booster(filename));
*out = new Booster(filename);
API_END();
}
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN();
delete reinterpret_cast<std::shared_ptr<Booster>*>(handle);
delete reinterpret_cast<Booster*>(handle);
API_END();
}
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter()) {
*is_finished = 1;
} else {
......@@ -459,7 +457,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* hess,
int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1;
} else {
......@@ -473,7 +471,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
int64_t* out_len,
float* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
auto result_buf = boosting->GetEvalAt(data);
*out_len = static_cast<int64_t>(result_buf.size());
......@@ -487,7 +485,7 @@ DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
int64_t* out_len,
const float** out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
int len = 0;
*out_result = ref_booster->GetTrainingScore(&len);
*out_len = static_cast<int64_t>(len);
......@@ -499,7 +497,7 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int64_t* out_len,
float* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
int len = 0;
boosting->GetPredictAt(data, out_result, &len);
......@@ -514,7 +512,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename,
const char* result_filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
bool bool_data_has_header = data_has_header > 0 ? true : false;
ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
......@@ -534,7 +532,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t n_used_trees,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
......@@ -561,7 +559,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
int64_t n_used_trees,
double* out_result) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
......@@ -581,7 +579,7 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model,
const char* filename) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<std::shared_ptr<Booster>*>(handle)->get();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_used_model, filename);
API_END();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment