Commit fc383361 authored by Guolin Ke's avatar Guolin Ke
Browse files

remove data name in metric

parent 8639107f
......@@ -222,7 +222,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[],
const char* valid_names[],
int n_valid_datas,
const char* parameters,
const char* init_model_filename,
......@@ -267,6 +266,18 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* hess,
int* is_finished);
/*!
* \brief Get total number of eval results
* \param out_len [out] total number of eval results
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len);
/*!
* \brief Get names of the eval results
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs);
/*!
* \brief get evaluation for training data and validation data
* \param handle handle
......@@ -275,7 +286,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
* \param out_result the string containing evaluation statistics, should allocate memory before call this function
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterEval(BoosterHandle handle,
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int data,
int64_t* out_len,
float* out_results);
......@@ -287,7 +298,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle,
int64_t* out_len,
const float** out_result);
......
......@@ -24,8 +24,7 @@ public:
* \param metadata Label data
* \param num_data Number of data
*/
virtual void Init(const char* test_name,
const Metadata& metadata, data_size_t num_data) = 0;
virtual void Init(const Metadata& metadata, data_size_t num_data) = 0;
virtual const std::vector<std::string>& GetName() const = 0;
......
......@@ -139,8 +139,7 @@ void Application::LoadData() {
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init("training", train_data_->metadata(),
train_data_->num_data());
metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric));
}
}
......@@ -164,9 +163,8 @@ void Application::LoadData() {
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(config_.io_config.valid_data_filenames[i].c_str(),
valid_datas_.back()->metadata(),
valid_datas_.back()->num_data());
metric->Init(valid_datas_.back()->metadata(),
valid_datas_.back()->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
......
......@@ -236,7 +236,7 @@ bool GBDT::OutputMetric(int iter) {
auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score());
for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), scores[k]);
Log::Info("Iteration:%d, training %s : %f", iter, name[k].c_str(), scores[k]);
}
}
}
......@@ -248,7 +248,7 @@ bool GBDT::OutputMetric(int iter) {
if ((iter % gbdt_config_->output_freq) == 0) {
auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) {
Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), test_scores[k]);
Log::Info("Iteration:%d, valid_%d %s : %f", iter, i + 1, name[k].c_str(), test_scores[k]);
}
}
if (!ret && early_stopping_round_ > 0) {
......
......@@ -29,7 +29,6 @@ public:
Booster(const Dataset* train_data,
std::vector<const Dataset*> valid_data,
std::vector<std::string> valid_names,
const char* parameters)
:train_data_(train_data), valid_datas_(valid_data) {
config_.LoadFromString(parameters);
......@@ -50,8 +49,7 @@ public:
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init("training", train_data_->metadata(),
train_data_->num_data());
metric->Init(train_data_->metadata(), train_data_->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
......@@ -61,9 +59,7 @@ public:
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_names[i].c_str(),
valid_datas_[i]->metadata(),
valid_datas_[i]->num_data());
metric->Init(valid_datas_[i]->metadata(), valid_datas_[i]->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
......@@ -82,12 +78,15 @@ public:
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
}
}
/*!
* \brief Load an existing model from file into the underlying boosting object.
* \param filename path of the model file to load
*/
void LoadModelFromFile(const char* filename) {
Boosting::LoadFileToBoosting(boosting_.get(), filename);
}
/*!
* \brief Destructor. Intentionally empty: members (boosting_ and the metric
* vectors hold std::unique_ptr elsewhere in this class) release themselves.
*/
~Booster() {
}
/*!
* \brief Perform one boosting iteration using the objective's own
*        gradients/hessians (both passed as nullptr, is_eval = false).
* \return the flag returned by Boosting::TrainOneIter
*         NOTE(review): callers store this into an "is_finished" out-param,
*         so it presumably means "cannot train further" — confirm.
*/
bool TrainOneIter() {
return boosting_->TrainOneIter(nullptr, nullptr, false);
}
......@@ -121,7 +120,25 @@ public:
/*!
* \brief Save the current model to a file.
* \param num_used_model number of models/iterations to save
*        (the test code passes -1, presumably meaning "all" — confirm)
* \param filename output file path
*/
void SaveModelToFile(int num_used_model, const char* filename) {
// second argument is hard-coded to true — TODO(review): confirm its meaning
// against Boosting::SaveModelToFile's signature (not visible here)
boosting_->SaveModelToFile(num_used_model, true, filename);
}
/*!
* \brief Count the total number of evaluation results produced by the
*        training metrics (a metric may report more than one name,
*        e.g. NDCG at several cutoffs).
* \return total number of eval results
*/
int GetEvalCounts() const {
  int total = 0;
  for (const auto& m : train_metric_) {
    const auto& names = m->GetName();
    total += static_cast<int>(names.size());
  }
  return total;
}
/*!
* \brief Export pointers to the training metrics' result-name strings.
* \param out_strs destination; the caller must provide at least
*        GetEvalCounts() slots
* \return number of names written
* NOTE(review): `*(out_strs[idx++])` indexes out_strs itself and then
* dereferences each element, i.e. it treats the argument as an array of
* `const char**`, not as a pointer to one array of `const char*`
* (`(*out_strs)[idx]`). Confirm this matches the C-API caller's allocation
* convention before relying on it.
* NOTE(review): the stored pointers alias std::string buffers owned by
* train_metric_ — they are only valid while this Booster is alive.
*/
int GetEvalNames(const char*** out_strs) const {
int idx = 0;
for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) {
*(out_strs[idx++]) = name.c_str();
}
}
return idx;
}
/*! \brief Non-owning read-only access to the underlying boosting object. */
const Boosting* GetBoosting() const { return boosting_.get(); }
/*! \brief Current scores on the training data; *out_len receives the length. Pointer is owned by the boosting object. */
const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); }
......@@ -412,7 +429,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[],
const char* valid_names[],
int n_valid_datas,
const char* parameters,
const char* init_model_filename,
......@@ -420,12 +436,10 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas;
std::vector<std::string> p_valid_names;
for (int i = 0; i < n_valid_datas; ++i) {
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
p_valid_names.emplace_back(valid_names[i]);
}
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters));
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, parameters));
if (init_model_filename != nullptr) {
ret->LoadModelFromFile(init_model_filename);
}
......@@ -472,7 +486,30 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
API_END();
}
DllExport int LGBM_BoosterEval(BoosterHandle handle,
/*!
* \brief Get total number of eval results (across all training metrics)
* \param handle handle
* \param out_len [out] total number of eval results
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalCounts();
API_END();
}
/*!
* \brief Get names of the eval results
*        (fixed: previous brief was copy-pasted from LGBM_BoosterGetEvalCounts)
* \param handle handle
* \param out_len [out] number of names written
* \param out_strs [out] receives pointers to the name strings; the caller must
*        allocate enough slots (use LGBM_BoosterGetEvalCounts first)
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalNames(out_strs);
API_END();
}
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
int data,
int64_t* out_len,
float* out_results) {
......@@ -487,7 +524,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
API_END();
}
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle,
int64_t* out_len,
const float** out_result) {
API_BEGIN();
......
......@@ -29,11 +29,8 @@ public:
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << "'s : " << PointWiseLossCalculator::Name();
name_.emplace_back(str_buf.str());
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back(PointWiseLossCalculator::Name());
num_data_ = num_data;
// get label
......@@ -162,10 +159,8 @@ public:
return 1.0f;
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << "'s : AUC";
name_.emplace_back(str_buf.str());
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back("AUC");
num_data_ = num_data;
// get label
......
......@@ -23,10 +23,9 @@ public:
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << " : " << PointWiseLossCalculator::Name();
name_.emplace_back(str_buf.str());
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back(PointWiseLossCalculator::Name());
num_data_ = num_data;
// get label
label_ = metadata.label();
......
......@@ -33,12 +33,9 @@ public:
~NDCGMetric() {
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
void Init(const Metadata& metadata, data_size_t num_data) override {
for (auto k : eval_at_) {
std::stringstream str_buf;
str_buf << test_name << "'s : ";
str_buf << "NDCG@" + std::to_string(k) + " ";
name_.emplace_back(str_buf.str());
name_.emplace_back(std::string("NDCG@") + std::to_string(k));
}
num_data_ = num_data;
// get label
......
......@@ -31,10 +31,8 @@ public:
return -1.0f;
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << " : " << PointWiseLossCalculator::Name();
name_.emplace_back(str_buf.str());
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back(PointWiseLossCalculator::Name());
num_data_ = num_data;
// get label
......
......@@ -175,16 +175,15 @@ def test_dataset():
# Smoke test for the Booster C API via ctypes: build a binary-classification
# booster, train 100 iterations printing test AUC each time, save the model.
# NOTE(review): this span is a rendered diff — for two calls below BOTH the
# removed (pre-commit) and added (post-commit) lines are present; only the
# second of each pair is current code. Indentation was stripped by the scrape.
def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)]
# removed by this commit: valid_names no longer passed to LGBM_BoosterCreate
name = [c_str('test')]
booster = ctypes.c_void_p()
# old call (removed): included c_array(ctypes.c_char_p, name)
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name),
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test),
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster))
is_finished = ctypes.c_int(0)
for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
result = np.array([0.0], dtype=np.float32)
out_len = ctypes.c_ulong(0)
# old call (removed): API renamed LGBM_BoosterEval -> LGBM_BoosterGetEval
LIB.LGBM_BoosterEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
print ('%d Iteration test AUC %f' %(i, result[0]))
# -1 presumably means "save all iterations" — confirm against C API docs
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment