Commit fc383361 authored by Guolin Ke's avatar Guolin Ke
Browse files

remove data name in metric

parent 8639107f
...@@ -222,7 +222,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -222,7 +222,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
*/ */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[], const DatesetHandle valid_datas[],
const char* valid_names[],
int n_valid_datas, int n_valid_datas,
const char* parameters, const char* parameters,
const char* init_model_filename, const char* init_model_filename,
...@@ -267,6 +266,18 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -267,6 +266,18 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* hess, const float* hess,
int* is_finished); int* is_finished);
/*!
* \brief Get number of eval
* \return total number of eval result
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len);
/*!
* \brief Get number of eval
* \return total number of eval result
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs);
/*! /*!
* \brief get evaluation for training data and validation data * \brief get evaluation for training data and validation data
* \param handle handle * \param handle handle
...@@ -275,7 +286,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -275,7 +286,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
* \param out_result the string containing evaluation statistics, should allocate memory before call this function * \param out_result the string containing evaluation statistics, should allocate memory before call this function
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterEval(BoosterHandle handle, DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int data, int data,
int64_t* out_len, int64_t* out_len,
float* out_results); float* out_results);
...@@ -287,7 +298,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, ...@@ -287,7 +298,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
* \param out_result used to set a pointer to array * \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterGetScore(BoosterHandle handle, DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
const float** out_result); const float** out_result);
......
...@@ -24,8 +24,7 @@ public: ...@@ -24,8 +24,7 @@ public:
* \param metadata Label data * \param metadata Label data
* \param num_data Number of data * \param num_data Number of data
*/ */
virtual void Init(const char* test_name, virtual void Init(const Metadata& metadata, data_size_t num_data) = 0;
const Metadata& metadata, data_size_t num_data) = 0;
virtual const std::vector<std::string>& GetName() const = 0; virtual const std::vector<std::string>& GetName() const = 0;
......
...@@ -139,8 +139,7 @@ void Application::LoadData() { ...@@ -139,8 +139,7 @@ void Application::LoadData() {
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config)); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init("training", train_data_->metadata(), metric->Init(train_data_->metadata(), train_data_->num_data());
train_data_->num_data());
train_metric_.push_back(std::move(metric)); train_metric_.push_back(std::move(metric));
} }
} }
...@@ -164,9 +163,8 @@ void Application::LoadData() { ...@@ -164,9 +163,8 @@ void Application::LoadData() {
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config)); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init(config_.io_config.valid_data_filenames[i].c_str(), metric->Init(valid_datas_.back()->metadata(),
valid_datas_.back()->metadata(), valid_datas_.back()->num_data());
valid_datas_.back()->num_data());
valid_metrics_.back().push_back(std::move(metric)); valid_metrics_.back().push_back(std::move(metric));
} }
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
......
...@@ -236,7 +236,7 @@ bool GBDT::OutputMetric(int iter) { ...@@ -236,7 +236,7 @@ bool GBDT::OutputMetric(int iter) {
auto name = sub_metric->GetName(); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score()); auto scores = sub_metric->Eval(train_score_updater_->score());
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), scores[k]); Log::Info("Iteration:%d, training %s : %f", iter, name[k].c_str(), scores[k]);
} }
} }
} }
...@@ -248,7 +248,7 @@ bool GBDT::OutputMetric(int iter) { ...@@ -248,7 +248,7 @@ bool GBDT::OutputMetric(int iter) {
if ((iter % gbdt_config_->output_freq) == 0) { if ((iter % gbdt_config_->output_freq) == 0) {
auto name = valid_metrics_[i][j]->GetName(); auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), test_scores[k]); Log::Info("Iteration:%d, valid_%d %s : %f", iter, i + 1, name[k].c_str(), test_scores[k]);
} }
} }
if (!ret && early_stopping_round_ > 0) { if (!ret && early_stopping_round_ > 0) {
......
...@@ -29,7 +29,6 @@ public: ...@@ -29,7 +29,6 @@ public:
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
std::vector<const Dataset*> valid_data, std::vector<const Dataset*> valid_data,
std::vector<std::string> valid_names,
const char* parameters) const char* parameters)
:train_data_(train_data), valid_datas_(valid_data) { :train_data_(train_data), valid_datas_(valid_data) {
config_.LoadFromString(parameters); config_.LoadFromString(parameters);
...@@ -50,8 +49,7 @@ public: ...@@ -50,8 +49,7 @@ public:
auto metric = std::unique_ptr<Metric>( auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config)); Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init("training", train_data_->metadata(), metric->Init(train_data_->metadata(), train_data_->num_data());
train_data_->num_data());
train_metric_.push_back(std::move(metric)); train_metric_.push_back(std::move(metric));
} }
train_metric_.shrink_to_fit(); train_metric_.shrink_to_fit();
...@@ -61,9 +59,7 @@ public: ...@@ -61,9 +59,7 @@ public:
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config)); auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init(valid_names[i].c_str(), metric->Init(valid_datas_[i]->metadata(), valid_datas_[i]->num_data());
valid_datas_[i]->metadata(),
valid_datas_[i]->num_data());
valid_metrics_.back().push_back(std::move(metric)); valid_metrics_.back().push_back(std::move(metric));
} }
valid_metrics_.back().shrink_to_fit(); valid_metrics_.back().shrink_to_fit();
...@@ -82,12 +78,15 @@ public: ...@@ -82,12 +78,15 @@ public:
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i])); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
} }
} }
void LoadModelFromFile(const char* filename) { void LoadModelFromFile(const char* filename) {
Boosting::LoadFileToBoosting(boosting_.get(), filename); Boosting::LoadFileToBoosting(boosting_.get(), filename);
} }
~Booster() { ~Booster() {
} }
bool TrainOneIter() { bool TrainOneIter() {
return boosting_->TrainOneIter(nullptr, nullptr, false); return boosting_->TrainOneIter(nullptr, nullptr, false);
} }
...@@ -121,7 +120,25 @@ public: ...@@ -121,7 +120,25 @@ public:
void SaveModelToFile(int num_used_model, const char* filename) { void SaveModelToFile(int num_used_model, const char* filename) {
boosting_->SaveModelToFile(num_used_model, true, filename); boosting_->SaveModelToFile(num_used_model, true, filename);
} }
int GetEvalCounts() const {
int ret = 0;
for (const auto& metric : train_metric_) {
ret += static_cast<int>(metric->GetName().size());
}
return ret;
}
int GetEvalNames(const char*** out_strs) const {
int idx = 0;
for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) {
*(out_strs[idx++]) = name.c_str();
}
}
return idx;
}
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); } const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); }
...@@ -412,7 +429,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -412,7 +429,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[], const DatesetHandle valid_datas[],
const char* valid_names[],
int n_valid_datas, int n_valid_datas,
const char* parameters, const char* parameters,
const char* init_model_filename, const char* init_model_filename,
...@@ -420,12 +436,10 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, ...@@ -420,12 +436,10 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
API_BEGIN(); API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data); const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas; std::vector<const Dataset*> p_valid_datas;
std::vector<std::string> p_valid_names;
for (int i = 0; i < n_valid_datas; ++i) { for (int i = 0; i < n_valid_datas; ++i) {
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i])); p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
p_valid_names.emplace_back(valid_names[i]);
} }
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters)); auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, parameters));
if (init_model_filename != nullptr) { if (init_model_filename != nullptr) {
ret->LoadModelFromFile(init_model_filename); ret->LoadModelFromFile(init_model_filename);
} }
...@@ -472,7 +486,30 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -472,7 +486,30 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
API_END(); API_END();
} }
DllExport int LGBM_BoosterEval(BoosterHandle handle, /*!
* \brief Get number of eval
* \return total number of eval result
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalCounts();
API_END();
}
/*!
* \brief Get number of eval
* \return total number of eval result
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalNames(out_strs);
API_END();
}
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int data, int data,
int64_t* out_len, int64_t* out_len,
float* out_results) { float* out_results) {
...@@ -487,7 +524,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, ...@@ -487,7 +524,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetScore(BoosterHandle handle, DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
const float** out_result) { const float** out_result) {
API_BEGIN(); API_BEGIN();
......
...@@ -29,11 +29,8 @@ public: ...@@ -29,11 +29,8 @@ public:
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back(PointWiseLossCalculator::Name());
std::stringstream str_buf;
str_buf << test_name << "'s : " << PointWiseLossCalculator::Name();
name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
...@@ -162,10 +159,8 @@ public: ...@@ -162,10 +159,8 @@ public:
return 1.0f; return 1.0f;
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf; name_.emplace_back("AUC");
str_buf << test_name << "'s : AUC";
name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
......
...@@ -23,10 +23,9 @@ public: ...@@ -23,10 +23,9 @@ public:
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << " : " << PointWiseLossCalculator::Name(); name_.emplace_back(PointWiseLossCalculator::Name());
name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
......
...@@ -33,12 +33,9 @@ public: ...@@ -33,12 +33,9 @@ public:
~NDCGMetric() { ~NDCGMetric() {
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
for (auto k : eval_at_) { for (auto k : eval_at_) {
std::stringstream str_buf; name_.emplace_back(std::string("NDCG@") + std::to_string(k));
str_buf << test_name << "'s : ";
str_buf << "NDCG@" + std::to_string(k) + " ";
name_.emplace_back(str_buf.str());
} }
num_data_ = num_data; num_data_ = num_data;
// get label // get label
......
...@@ -31,10 +31,8 @@ public: ...@@ -31,10 +31,8 @@ public:
return -1.0f; return -1.0f;
} }
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf; name_.emplace_back(PointWiseLossCalculator::Name());
str_buf << test_name << " : " << PointWiseLossCalculator::Name();
name_.emplace_back(str_buf.str());
num_data_ = num_data; num_data_ = num_data;
// get label // get label
......
...@@ -175,16 +175,15 @@ def test_dataset(): ...@@ -175,16 +175,15 @@ def test_dataset():
def test_booster(): def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None) train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)] test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)]
name = [c_str('test')]
booster = ctypes.c_void_p() booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name), LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test),
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster)) len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster))
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
for i in range(100): for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished)) LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
result = np.array([0.0], dtype=np.float32) result = np.array([0.0], dtype=np.float32)
out_len = ctypes.c_ulong(0) out_len = ctypes.c_ulong(0)
LIB.LGBM_BoosterEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))) LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
print ('%d Iteration test AUC %f' %(i, result[0])) print ('%d Iteration test AUC %f' %(i, result[0]))
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt')) LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster) LIB.LGBM_BoosterFree(booster)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment