"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "bf1a604a5b5106d6b7f2aa07ea02be12115dcabc"
Commit a178b75b authored by Guolin Ke's avatar Guolin Ke
Browse files

change some c_api interfaces for better compatibility

parent 6837efe7
...@@ -151,6 +151,7 @@ public: ...@@ -151,6 +151,7 @@ public:
/*! \brief Disable copy */ /*! \brief Disable copy */
Boosting(const Boosting&) = delete; Boosting(const Boosting&) = delete;
/*!
* \brief Read a model file into an existing boosting object
* \param boosting Object that receives the file content; the definition only reads the file when this is not nullptr
* \param filename Path of the model file to read
*/
static void LoadFileToBoosting(Boosting* boosting, const char* filename);
/*! /*!
* \brief Create boosting object * \brief Create boosting object
* \param type Type of boosting * \param type Type of boosting
......
...@@ -165,7 +165,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, ...@@ -165,7 +165,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
* \param field_name field name, can be label, weight, group * \param field_name field name, can be label, weight, group
* \param field_data pointer to vector * \param field_data pointer to vector
* \param num_element number of element in field_data * \param num_element number of element in field_data
* \param type float_32:0, int32_t:1 * \param type float32 or int32
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_DatasetSetField(DatesetHandle handle, DllExport int LGBM_DatasetSetField(DatesetHandle handle,
...@@ -180,7 +180,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, ...@@ -180,7 +180,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
* \param field_name field name * \param field_name field name
* \param out_len used to set result length * \param out_len used to set result length
* \param out_ptr pointer to the result * \param out_ptr pointer to the result
* \param out_type float_32:0, int32_t:1 * \param out_type float32 or int32
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_DatasetGetField(DatesetHandle handle, DllExport int LGBM_DatasetGetField(DatesetHandle handle,
...@@ -216,6 +216,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -216,6 +216,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
* \param valid_names names of validation data sets * \param valid_names names of validation data sets
* \param n_valid_datas number of validation set * \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2' * \param parameters format: 'key1=value1 key2=value2'
* \param init_model_filename filename of an existing model to load into the new Booster, pass NULL to skip loading
* \param out handle of created Booster * \param out handle of created Booster
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -224,6 +225,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -224,6 +225,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const char* valid_names[], const char* valid_names[],
int n_valid_datas, int n_valid_datas,
const char* parameters, const char* parameters,
const char* init_model_filename,
BoosterHandle* out); BoosterHandle* out);
/*! /*!
...@@ -232,7 +234,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -232,7 +234,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
* \param out handle of created Booster * \param out handle of created Booster
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterLoadFromModelfile( DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
BoosterHandle* out); BoosterHandle* out);
......
...@@ -83,6 +83,8 @@ public: ...@@ -83,6 +83,8 @@ public:
void SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len); void SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len);
/*!
* \brief Set query boundaries from per-record query ids
* \param query_id Query id of every record; a new query starts wherever the id differs from the previous record's id
* \param len Number of entries in query_id, a fatal error is logged if it differs from the number of data
*/
void SetQueryId(const data_size_t* query_id, data_size_t len);
/*! /*!
* \brief Set initial scores * \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score. * \param init_score Initial scores, this class will manage memory for init_score.
......
...@@ -15,7 +15,7 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) { ...@@ -15,7 +15,7 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) {
return BoostingType::kUnknow; return BoostingType::kUnknow;
} }
void LoadFileToBoosting(Boosting* boosting, const char* filename) { void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (boosting != nullptr) { if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true); TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines(); model_reader.ReadAllLines();
......
...@@ -82,11 +82,12 @@ public: ...@@ -82,11 +82,12 @@ public:
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i])); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
} }
} }
void LoadModelFromFile(const char* filename) {
Boosting::LoadFileToBoosting(boosting_.get(), filename);
}
~Booster() { ~Booster() {
} }
bool TrainOneIter() { bool TrainOneIter() {
return boosting_->TrainOneIter(nullptr, nullptr, false); return boosting_->TrainOneIter(nullptr, nullptr, false);
} }
...@@ -414,6 +415,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -414,6 +415,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const char* valid_names[], const char* valid_names[],
int n_valid_datas, int n_valid_datas,
const char* parameters, const char* parameters,
const char* init_model_filename,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data); const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
...@@ -423,11 +425,15 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -423,11 +425,15 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i])); p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
p_valid_names.emplace_back(valid_names[i]); p_valid_names.emplace_back(valid_names[i]);
} }
*out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters); auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters));
if (init_model_filename != nullptr) {
ret->LoadModelFromFile(init_model_filename);
}
*out = ret.release();
API_END(); API_END();
} }
DllExport int LGBM_BoosterLoadFromModelfile( DllExport int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
......
...@@ -78,6 +78,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si ...@@ -78,6 +78,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
name = Common::Trim(name); name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) { if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQueryBoundaries(field_data, num_element); metadata_.SetQueryBoundaries(field_data, num_element);
} else if (name == std::string("query_id") || name == std::string("group_id")) {
metadata_.SetQueryId(field_data, num_element);
} else { } else {
return false; return false;
} }
......
...@@ -248,6 +248,39 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size ...@@ -248,6 +248,39 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
LoadQueryWeights(); LoadQueryWeights();
} }
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("len of query id is not same with #data");
}
if (queries_.size() > 0) { queries_.clear(); }
queries_ = std::vector<data_size_t>(num_data_);
for (data_size_t i = 0; i < num_weights_; ++i) {
queries_[i] = query_id[i];
}
// need convert query_id to boundaries
std::vector<data_size_t> tmp_buffer;
data_size_t last_qid = -1;
data_size_t cur_cnt = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (last_qid != queries_[i]) {
if (cur_cnt > 0) {
tmp_buffer.push_back(cur_cnt);
}
cur_cnt = 0;
last_qid = queries_[i];
}
++cur_cnt;
}
tmp_buffer.push_back(cur_cnt);
query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
query_boundaries_[0] = 0;
for (size_t i = 0; i < tmp_buffer.size(); ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
}
queries_.clear();
LoadQueryWeights();
}
void Metadata::LoadWeights() { void Metadata::LoadWeights() {
num_weights_ = 0; num_weights_ = 0;
......
...@@ -178,7 +178,7 @@ def test_booster(): ...@@ -178,7 +178,7 @@ def test_booster():
name = [c_str('test')] name = [c_str('test')]
booster = ctypes.c_void_p() booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name), LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name),
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster)) len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster))
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
for i in range(100): for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished)) LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
...@@ -191,7 +191,7 @@ def test_booster(): ...@@ -191,7 +191,7 @@ def test_booster():
test_free_dataset(train) test_free_dataset(train)
test_free_dataset(test[0]) test_free_dataset(test[0])
booster2 = ctypes.c_void_p() booster2 = ctypes.c_void_p()
LIB.LGBM_BoosterLoadFromModelfile(c_str('model.txt'), ctypes.byref(booster2)) LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(booster2))
data = [] data = []
inp = open('../../examples/binary_classification/binary.test', 'r') inp = open('../../examples/binary_classification/binary.test', 'r')
for line in inp.readlines(): for line in inp.readlines():
...@@ -214,4 +214,3 @@ def test_booster(): ...@@ -214,4 +214,3 @@ def test_booster():
test_dataset() test_dataset()
test_booster() test_booster()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment