"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "4bb2f2f854e674493e2a9aae91d2c2363f7652b1"
Commit a178b75b authored by Guolin Ke's avatar Guolin Ke
Browse files

change some c_api interfaces for better compatibility

parent 6837efe7
......@@ -151,6 +151,7 @@ public:
/*! \brief Disable copy */
Boosting(const Boosting&) = delete;
static void LoadFileToBoosting(Boosting* boosting, const char* filename);
/*!
* \brief Create boosting object
* \param type Type of boosting
......
......@@ -165,7 +165,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
* \param field_name field name, can be label, weight, group
* \param field_data pointer to vector
* \param num_element number of element in field_data
* \param type float_32:0, int32_t:1
* \param type float32 or int32
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
......@@ -180,7 +180,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
* \param field_name field name
* \param out_len used to set result length
* \param out_ptr pointer to the result
* \param out_type float_32:0, int32_t:1
* \param out_type float32 or int32
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
......@@ -216,6 +216,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
* \param valid_names names of validation data sets
* \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2'
* \param init_model_filename filename of model
* \param out handle of created Booster
* \return 0 when success, -1 when failure happens
*/
......@@ -224,6 +225,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const char* valid_names[],
int n_valid_datas,
const char* parameters,
const char* init_model_filename,
BoosterHandle* out);
/*!
......@@ -232,7 +234,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
* \param out handle of created Booster
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterLoadFromModelfile(
DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename,
BoosterHandle* out);
......
......@@ -83,6 +83,8 @@ public:
void SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len);
void SetQueryId(const data_size_t* query_id, data_size_t len);
/*!
* \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score.
......
......@@ -15,7 +15,7 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) {
return BoostingType::kUnknow;
}
void LoadFileToBoosting(Boosting* boosting, const char* filename) {
void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
......
......@@ -82,11 +82,12 @@ public:
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
}
}
/*!
* \brief Load a model file into the boosting object held by this Booster
* \param filename Path that is forwarded unchanged to Boosting::LoadFileToBoosting
*/
void LoadModelFromFile(const char* filename) {
  // Hand the raw pointer held by boosting_ to the static loader.
  auto* target = boosting_.get();
  Boosting::LoadFileToBoosting(target, filename);
}
/*! \brief Destructor; performs no explicit work in its body */
~Booster() {}
/*!
* \brief Run one training iteration by delegating to the held boosting object
* \return The value returned by boosting_->TrainOneIter
*/
bool TrainOneIter() {
  // The first two arguments are passed as null and the third as false;
  // NOTE(review): their meaning is defined by Boosting::TrainOneIter, which is not visible here.
  const bool result = boosting_->TrainOneIter(nullptr, nullptr, false);
  return result;
}
......@@ -414,6 +415,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const char* valid_names[],
int n_valid_datas,
const char* parameters,
const char* init_model_filename,
BoosterHandle* out) {
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
......@@ -423,11 +425,15 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
p_valid_names.emplace_back(valid_names[i]);
}
*out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters));
if (init_model_filename != nullptr) {
ret->LoadModelFromFile(init_model_filename);
}
*out = ret.release();
API_END();
}
DllExport int LGBM_BoosterLoadFromModelfile(
DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename,
BoosterHandle* out) {
API_BEGIN();
......
......@@ -78,6 +78,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQueryBoundaries(field_data, num_element);
} else if (name == std::string("query_id") || name == std::string("group_id")) {
metadata_.SetQueryId(field_data, num_element);
} else {
return false;
}
......
......@@ -248,6 +248,39 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
LoadQueryWeights();
}
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("len of query id is not same with #data");
}
if (queries_.size() > 0) { queries_.clear(); }
queries_ = std::vector<data_size_t>(num_data_);
for (data_size_t i = 0; i < num_weights_; ++i) {
queries_[i] = query_id[i];
}
// need convert query_id to boundaries
std::vector<data_size_t> tmp_buffer;
data_size_t last_qid = -1;
data_size_t cur_cnt = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (last_qid != queries_[i]) {
if (cur_cnt > 0) {
tmp_buffer.push_back(cur_cnt);
}
cur_cnt = 0;
last_qid = queries_[i];
}
++cur_cnt;
}
tmp_buffer.push_back(cur_cnt);
query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
query_boundaries_[0] = 0;
for (size_t i = 0; i < tmp_buffer.size(); ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
}
queries_.clear();
LoadQueryWeights();
}
void Metadata::LoadWeights() {
num_weights_ = 0;
......
......@@ -178,7 +178,7 @@ def test_booster():
name = [c_str('test')]
booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name),
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster))
is_finished = ctypes.c_int(0)
for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
......@@ -191,7 +191,7 @@ def test_booster():
test_free_dataset(train)
test_free_dataset(test[0])
booster2 = ctypes.c_void_p()
LIB.LGBM_BoosterLoadFromModelfile(c_str('model.txt'), ctypes.byref(booster2))
LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(booster2))
data = []
inp = open('../../examples/binary_classification/binary.test', 'r')
for line in inp.readlines():
......@@ -214,4 +214,3 @@ def test_booster():
test_dataset()
test_booster()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment