Commit b3db9e92 authored by Belinda Trotta, committed by Nikita Titov

Top k multi error (#2178)

* Implement top-k multiclass error metric. Add new parameter top_k_threshold.

* Add test for multiclass metrics

* Make test less sensitive to avoid floating-point issues.

* Change tabs to spaces.

* Fix problem with test in Python 2. Refactor to use np.testing. Decrease number of training rounds so loss is larger and easier to compare.

* Move multiclass tests into test_engine.py

* Change parameter name from top_k_threshold to multi_error_top_k.

* Fix top-k error metric to handle case where scores are equal. Update tests and docs.

* Change name of top-k metric to multi_error@k.

* Change tabs to spaces.

* Fix formatting.

* Fix minor issues in docs.
parent 19de2be0
@@ -843,6 +843,18 @@ Metric Parameters
- `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ and `MAP <https://makarandtapaswi.wordpress.com/2012/07/02/intuition-behind-average-precision-and-map/>`__ evaluation positions, separated by ``,``
- ``multi_error_top_k`` :raw-html:`<a id="multi_error_top_k" title="Permalink to this parameter" href="#multi_error_top_k">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, constraints: ``multi_error_top_k > 0``
- used only with ``multi_error`` metric
- threshold for top-k multi-error metric
- the error on each sample is ``0`` if the true class is among the top ``multi_error_top_k`` predictions, and ``1`` otherwise
- more precisely, the error on a sample is ``0`` if there are at least ``num_classes - multi_error_top_k`` predictions strictly less than the prediction on the true class
- when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
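
A minimal NumPy sketch of this definition (the helper name ``top_k_multi_error`` is illustrative, not part of LightGBM's API):

    import numpy as np

    def top_k_multi_error(y_true, y_pred, k):
        # y_pred: (n_samples, num_classes) scores; y_true: integer class labels
        true_scores = y_pred[np.arange(len(y_true)), y_true]
        # a sample is correct when at least num_classes - k predictions
        # are strictly below the prediction on the true class
        num_below = (y_pred < true_scores[:, None]).sum(axis=1)
        return 1.0 - np.mean(num_below >= y_pred.shape[1] - k)

Ties are therefore counted pessimistically: if all predictions are equal, ``num_below`` is ``0``, so every sample is an error unless ``multi_error_top_k`` equals the number of classes.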
Network Parameters
------------------
...
@@ -747,6 +747,14 @@ struct Config {
// desc = `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ and `MAP <https://makarandtapaswi.wordpress.com/2012/07/02/intuition-behind-average-precision-and-map/>`__ evaluation positions, separated by ``,``
std::vector<int> eval_at;
// check = >0
// desc = used only with ``multi_error`` metric
// desc = threshold for top-k multi-error metric
// desc = the error on each sample is ``0`` if the true class is among the top ``multi_error_top_k`` predictions, and ``1`` otherwise
// descl2 = more precisely, the error on a sample is ``0`` if there are at least ``num_classes - multi_error_top_k`` predictions strictly less than the prediction on the true class
// desc = when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
int multi_error_top_k = 1;
#pragma endregion
#pragma region Network Parameters
...
@@ -260,6 +260,7 @@ std::unordered_set<std::string> Config::parameter_set({
"metric_freq",
"is_provide_training_metric",
"eval_at",
"multi_error_top_k",
"num_machines", "num_machines",
"local_listen_port", "local_listen_port",
"time_out", "time_out",
@@ -521,6 +522,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
eval_at = Common::StringToArray<int>(tmp_str, ',');
}
GetInt(params, "multi_error_top_k", &multi_error_top_k);
CHECK(multi_error_top_k >0);
GetInt(params, "num_machines", &num_machines);
CHECK(num_machines >0);
@@ -637,6 +641,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[num_machines: " << num_machines << "]\n"; str_buf << "[num_machines: " << num_machines << "]\n";
str_buf << "[local_listen_port: " << local_listen_port << "]\n"; str_buf << "[local_listen_port: " << local_listen_port << "]\n";
str_buf << "[time_out: " << time_out << "]\n"; str_buf << "[time_out: " << time_out << "]\n";
...
@@ -20,7 +20,7 @@ namespace LightGBM {
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
explicit MulticlassMetric(const Config& config) :config_(config) {
num_class_ = config.num_class;
}
@@ -28,7 +28,7 @@ class MulticlassMetric: public Metric {
}
void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back(PointWiseLossCalculator::Name(config_));
num_data_ = num_data;
// get label
label_ = metadata.label();
@@ -72,7 +72,7 @@ class MulticlassMetric: public Metric {
std::vector<double> rec(num_pred_per_row);
objective->ConvertOutput(raw_score.data(), rec.data());
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_);
}
} else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
@@ -85,7 +85,7 @@ class MulticlassMetric: public Metric {
std::vector<double> rec(num_pred_per_row);
objective->ConvertOutput(raw_score.data(), rec.data());
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_) * weights_[i];
}
}
} else {
@@ -98,7 +98,7 @@ class MulticlassMetric: public Metric {
rec[k] = static_cast<double>(score[idx]);
}
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_);
}
} else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
@@ -109,7 +109,7 @@ class MulticlassMetric: public Metric {
rec[k] = static_cast<double>(score[idx]);
}
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_) * weights_[i];
}
}
}
@@ -129,25 +129,28 @@ class MulticlassMetric: public Metric {
/*! \brief Name of this test set */
std::vector<std::string> name_;
int num_class_;
/*! \brief config parameters*/
Config config_;
};
/*! \brief top-k error for multiclass task; if k=1 (default) this is the usual multi-error */
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
explicit MultiErrorMetric(const Config& config) :MulticlassMetric<MultiErrorMetric>(config) {}
inline static double LossOnPoint(label_t label, std::vector<double>& score, const Config& config) {
size_t k = static_cast<size_t>(label);
int num_larger = 0;
for (size_t i = 0; i < score.size(); ++i) {
if (score[i] >= score[k]) ++num_larger;
if (num_larger > config.multi_error_top_k) return 1.0f;
}
return 0.0f;
}
inline static const std::string Name(const Config& config) {
if (config.multi_error_top_k == 1) return "multi_error";
else return "multi_error@" + std::to_string(config.multi_error_top_k);
}
};
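
Since Name() switches between multi_error and multi_error@k, the metric is recorded under that key in evaluation results. A short usage sketch (synthetic data, parameter values purely illustrative), mirroring how the tests below read the metric:

    import lightgbm as lgb
    import numpy as np

    X = np.random.rand(100, 5)
    y = np.random.randint(0, 10, 100)
    lgb_data = lgb.Dataset(X, label=y)
    results = {}
    lgb.train({'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error',
               'multi_error_top_k': 3, 'num_rounds': 5, 'verbose': -1},
              lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
    # with multi_error_top_k=3 the per-iteration values appear under 'multi_error@3'
    print(results['train']['multi_error@3'])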
@@ -156,7 +159,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr
public:
explicit MultiSoftmaxLoglossMetric(const Config& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
inline static double LossOnPoint(label_t label, std::vector<double>& score, const Config&) {
size_t k = static_cast<size_t>(label);
if (score[k] > kEpsilon) {
return static_cast<double>(-std::log(score[k]));
@@ -165,7 +168,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr
}
}
inline static const std::string Name(const Config&) {
return "multi_logloss";
}
};
...
@@ -26,6 +26,13 @@ def multi_logloss(y_true, y_pred):
    return np.mean([-math.log(y_pred[i][y]) for i, y in enumerate(y_true)])
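# Reference implementation of the top-k error used by the tests below:
# np.partition(-y_pred, k)[:, k:] collects each row's scores outside the top k,
# so max_rest is the best score not in the top k; a sample counts as correct
# only when its true-class score strictly exceeds max_rest, mirroring the
# strict comparison (and hence the tie handling) of the C++ metric.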
def top_k_error(y_true, y_pred, k):
    if k == y_pred.shape[1]:
        return 0
    max_rest = np.max(-np.partition(-y_pred, k)[:, k:], axis=1)
    return 1 - np.mean((y_pred[np.arange(len(y_true)), y_true] > max_rest))
class TestEngine(unittest.TestCase):
    def test_binary(self):
        X, y = load_breast_cancer(True)
@@ -363,6 +370,56 @@ class TestEngine(unittest.TestCase):
        ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
        self.assertLess(ret, 0.2)
    def test_multi_class_error(self):
        X, y = load_digits(return_X_y=True)
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'num_leaves': 4, 'seed': 0,
                  'num_rounds': 30, 'verbose': -1}
        lgb_data = lgb.Dataset(X, label=y)
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_default = est.predict(X)
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 1,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_1 = est.predict(X)
        # check that default gives same result as k = 1
        np.testing.assert_array_almost_equal(predict_1, predict_default, 5)
        # check against independent calculation for k = 1
        err = top_k_error(y, predict_1, 1)
        np.testing.assert_almost_equal(results['train']['multi_error'][-1], err, 5)
        # check against independent calculation for k = 2
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 2,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_2 = est.predict(X)
        err = top_k_error(y, predict_2, 2)
        np.testing.assert_almost_equal(results['train']['multi_error@2'][-1], err, 5)
        # check against independent calculation for k = 10
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 10,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_3 = est.predict(X)
        err = top_k_error(y, predict_3, 10)
        np.testing.assert_almost_equal(results['train']['multi_error@10'][-1], err, 5)
        # check case where predictions are equal
        X = np.array([[0, 0], [0, 0]])
        y = np.array([0, 1])
        lgb_data = lgb.Dataset(X, label=y)
        params = {'objective': 'multiclass', 'num_classes': 2, 'metric': 'multi_error', 'multi_error_top_k': 1,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 1, 'verbose': -1, 'metric_freq': 10}
        results = {}
        lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        np.testing.assert_almost_equal(results['train']['multi_error'][-1], 1, 5)
        lgb_data = lgb.Dataset(X, label=y)
        params = {'objective': 'multiclass', 'num_classes': 2, 'metric': 'multi_error', 'multi_error_top_k': 2,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 1, 'verbose': -1, 'metric_freq': 10}
        results = {}
        lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        np.testing.assert_almost_equal(results['train']['multi_error@2'][-1], 0, 5)
    def test_early_stopping(self):
        X, y = load_breast_cancer(True)
        params = {
...