Unverified Commit 86530988 authored by sbruch's avatar sbruch Committed by GitHub

Implementation of XE_NDCG_MART for the ranking task (#2620)

* Implementation of XE_NDCG loss function for ranking.

* Add citation

* Check in example usage for xe_ndcg loss.

* Seed the generator when a seed is provided in the config. Add unit-tests for xe_ndcg

* Update documentation

* Fix indentation

* Address issues raised by reviewers.

* Clean up include statements.

* Fix issues raised by reviewers.

* Regenerate parameters.rst

* Add a note to explain that reproducing xe_ndcg results requires num_threads to be one.

* Introduce objective_seed and use that in rank_xendcg instead of directly using seed

* Change default value of objective_seed
parent ef0b2d82
...@@ -51,7 +51,7 @@ Core Parameters
- **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent functions
- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">&#x1F517;&#xFE0E;</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``cross_entropy``, ``cross_entropy_lambda``, ``lambdarank``, ``rank_xendcg``, aliases: ``objective_type``, ``app``, ``application``
- regression application
...@@ -99,6 +99,10 @@ Core Parameters
- all values in ``label`` must be smaller than number of elements in ``label_gain``
- ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function, aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``
- to obtain reproducible results, you should disable parallelism by setting ``num_threads`` to 1
- ``boosting`` :raw-html:`<a id="boosting" title="Permalink to this parameter" href="#boosting">&#x1F517;&#xFE0E;</a>`, default = ``gbdt``, type = enum, options: ``gbdt``, ``rf``, ``dart``, ``goss``, aliases: ``boosting_type``, ``boost``
- ``gbdt``, traditional Gradient Boosting Decision Tree, aliases: ``gbrt``
...@@ -852,6 +856,12 @@ Objective Parameters
- separate by ``,``
- ``objective_seed`` :raw-html:`<a id="objective_seed" title="Permalink to this parameter" href="#objective_seed">&#x1F517;&#xFE0E;</a>`, default = ``5``, type = int
- random seed for objectives
- used only in the ``rank_xendcg`` objective
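For illustration, the parameter above combines with the reproducibility note as follows in a hypothetical Python parameter dict (the keys are the documented LightGBM parameters; the values are examples only):

```python
# Example parameter dict for reproducible rank_xendcg training.
# All keys are documented LightGBM parameters; values are illustrative.
params = {
    "objective": "rank_xendcg",
    "objective_seed": 5,  # seed consumed by the rank_xendcg objective
    "num_threads": 1,     # reproducible results require single-threaded training
    "metric": "ndcg",
}
```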
Metric Parameters
-----------------
......
XE_NDCG Ranking Example
=======================
Here is an example of using LightGBM to train a ranking model with the [XE_NDCG loss](https://arxiv.org/abs/1911.09798).
***You must follow the [installation instructions](https://lightgbm.readthedocs.io/en/latest/Installation-Guide.html)
for the following commands to work. The `lightgbm` binary must be built and available at the root of this project.***
Training
--------
Run the following command in this folder:
```bash
"../../lightgbm" config=train.conf
```
Prediction
----------
You should finish training first.
Run the following command in this folder:
```bash
"../../lightgbm" config=predict.conf
```
Data Format
-----------
To learn more about the query format used in this example, check out the
[query data format](https://lightgbm.readthedocs.io/en/latest/Parameters.html#query-data).
task = predict
data = rank.test
input_model = LightGBM_model.txt
This diff is collapsed.
12
19
18
10
15
15
22
23
18
16
16
11
6
13
17
21
20
16
13
16
21
15
10
19
10
13
18
17
23
24
16
13
17
24
17
10
17
15
18
16
9
9
21
14
13
13
13
10
10
6
This diff is collapsed.
1
13
5
8
19
12
18
5
14
13
8
9
16
11
21
14
21
9
14
11
20
18
13
20
22
22
13
17
10
13
12
13
13
23
18
13
20
12
22
14
13
23
13
14
14
5
13
15
14
14
16
16
15
21
22
10
22
18
25
16
12
12
15
15
25
13
9
12
8
16
25
19
24
12
16
10
16
9
17
15
7
9
15
14
16
17
8
17
12
18
23
10
12
12
4
14
12
15
27
16
20
13
19
13
17
17
16
12
15
14
14
19
12
23
18
16
9
23
11
15
8
10
10
16
11
15
22
16
17
23
16
22
17
14
12
14
20
15
17
15
15
22
9
21
9
17
16
15
13
13
15
14
18
21
14
17
15
14
16
12
17
19
16
11
18
11
13
14
9
16
15
16
25
9
13
22
16
18
20
14
11
9
16
19
19
11
11
13
14
14
13
16
6
21
16
12
16
11
24
12
10
# task type, support train and predict
task = train
# boosting type, support gbdt for now, alias: boosting, boost
boosting_type = gbdt
# application type, support following applications:
# regression , regression task
# binary , binary classification task
# lambdarank , lambdarank task
# rank_xendcg , XE_NDCG ranking task
# alias: application, app
objective = rank_xendcg
# eval metrics, support multi metric, delimited by ',', support following metrics
# l1
# l2 , default metric for regression
# ndcg , default metric for lambdarank
# auc
# binary_logloss , default metric for binary
# binary_error
metric = ndcg
# evaluation positions for ndcg metric, alias: ndcg_at
ndcg_eval_at = 1,3,5
# frequency for metric output
metric_freq = 1
# true if need output metric for training data, aliases: training_metric, train_metric
is_training_metric = true
# number of bins for feature buckets; 255 is a recommended setting, it can save memory and also gives good accuracy
max_bin = 255
# training data
# if there is a weight file, it should be named "rank.train.weight"
# if there is a query file, it should be named "rank.train.query"
# alias: train_data, train
data = rank.train
# validation data; multiple validation sets supported, separated by ','
# if there is a weight file, it should be named "rank.test.weight"
# if there is a query file, it should be named "rank.test.query"
# alias: valid, test, test_data
valid_data = rank.test
# number of trees(iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds
num_trees = 100
# shrinkage rate , alias: shrinkage_rate
learning_rate = 0.1
# number of leaves for one tree, alias: num_leaf
num_leaves = 31
# type of tree learner, support following types:
# serial , single machine version
# feature , use feature parallel to train
# data , use data parallel to train
# voting , use voting based parallel to train
# alias: tree
tree_learner = serial
# Set num_threads and objective_seed for stable unit-tests. Comment out otherwise.
num_threads = 1
objective_seed = 1025
# feature sub-sample; will randomly select this fraction of features on each iteration (1.0 disables sub-sampling)
# alias: sub_feature
feature_fraction = 1.0
# bagging (data sub-sample) frequency; will perform bagging every 1 iteration
bagging_freq = 1
# bagging fraction; will randomly select 90% of data for each bagging round
# alias: sub_row
bagging_fraction = 0.9
# minimal number of data points in one leaf; use this to deal with over-fitting
# aliases: min_data_per_leaf, min_data
min_data_in_leaf = 50
# minimal sum of hessians in one leaf; use this to deal with over-fitting
min_sum_hessian_in_leaf = 5.0
# save memory and faster speed for sparse feature, alias: is_sparse
is_enable_sparse = true
# when data is bigger than memory size, set this to true; otherwise setting it to false gives faster speed
# alias: two_round_loading, two_round
use_two_round_loading = false
# true if need to save data to a binary file; the application will auto-load from the binary file next time
# alias: is_save_binary, save_binary
is_save_binary_file = false
# output model file
output_model = LightGBM_model.txt
# support continuous train from trained gbdt model
# input_model= trained_model.txt
# output prediction file for predict task
# output_result= prediction.txt
# support continuous train from initial score file
# input_init_score= init_score.txt
# number of machines in parallel training, alias: num_machine
num_machines = 1
# local listening port in parallel training, alias: local_port
local_listen_port = 12400
# machines list file for parallel training, alias: mlist
machine_list_file = mlist.txt
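The config file format above is plain `key = value` lines with `#` comments. A minimal parser sketch (an illustrative helper, not LightGBM code) shows how such a file maps to a parameter dict:

```python
def parse_lightgbm_conf(text):
    """Parse 'key = value' lines, ignoring blank lines and '#' comments."""
    params = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop trailing comments
        if not line:
            continue
        key, _, value = line.partition("=")
        params[key.strip()] = value.strip()
    return params
```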
...@@ -102,7 +102,7 @@ struct Config {
// [doc-only]
// type = enum
// options = regression, regression_l1, huber, fair, poisson, quantile, mape, gamma, tweedie, binary, multiclass, multiclassova, cross_entropy, cross_entropy_lambda, lambdarank, rank_xendcg
// alias = objective_type, app, application
// desc = regression application
// descl2 = ``regression``, L2 loss, aliases: ``regression_l2``, ``l2``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
...@@ -127,6 +127,8 @@ struct Config {
// descl2 = label should be ``int`` type in lambdarank tasks, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)
// descl2 = `label_gain <#objective-parameters>`__ can be used to set the gain (weight) of ``int`` label
// descl2 = all values in ``label`` must be smaller than number of elements in ``label_gain``
// desc = ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function, aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``
// descl2 = to obtain reproducible results, you should disable parallelism by setting ``num_threads`` to 1
std::string objective = "regression";
// [doc-only]
...@@ -754,6 +756,10 @@ struct Config {
// desc = separate by ``,``
std::vector<double> label_gain;
// desc = random seed for objectives
// desc = used only in the ``rank_xendcg`` objective
int objective_seed = 5;
#pragma endregion
#pragma region Metric Parameters
...@@ -1004,6 +1010,9 @@ inline std::string ParseObjectiveAlias(const std::string& type) {
return "cross_entropy_lambda";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("rank_xendcg") || type == std::string("xendcg") || type == std::string("xe_ndcg")
|| type == std::string("xe_ndcg_mart") || type == std::string("xendcg_mart")) {
return "rank_xendcg";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
...@@ -1019,7 +1028,8 @@ inline std::string ParseMetricAlias(const std::string& type) {
return "l1";
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
return "binary_logloss";
} else if (type == std::string("ndcg") || type == std::string("lambdarank") || type == std::string("rank_xendcg")
|| type == std::string("xendcg") || type == std::string("xe_ndcg") || type == std::string("xe_ndcg_mart") || type == std::string("xendcg_mart")) {
return "ndcg";
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
return "map";
......
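The alias resolution added above canonicalizes every spelling of the new objective to ``rank_xendcg`` and gives it ``ndcg`` as its default metric, just like ``lambdarank``. A minimal Python sketch of the same mapping (illustrative, not the LightGBM API):

```python
# Every accepted spelling of the XE_NDCG objective, per the diff above.
XENDCG_ALIASES = {"rank_xendcg", "xendcg", "xe_ndcg", "xe_ndcg_mart", "xendcg_mart"}

def parse_objective_alias(name):
    # All XE_NDCG spellings collapse to the canonical objective name.
    return "rank_xendcg" if name in XENDCG_ALIASES else name

def parse_metric_alias(name):
    # Like lambdarank, the XE_NDCG objective maps to the ndcg metric.
    return "ndcg" if name in XENDCG_ALIASES | {"ndcg", "lambdarank"} else name
```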
...@@ -192,6 +192,7 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
drop_seed = static_cast<int>(rand.NextShort(0, int_max));
feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
objective_seed = static_cast<int>(rand.NextShort(0, int_max));
}
GetTaskType(params, &task);
......
...@@ -272,6 +272,7 @@ std::unordered_set<std::string> Config::parameter_set({
"max_position",
"lambdamart_norm",
"label_gain",
"objective_seed",
"metric",
"metric_freq",
"is_provide_training_metric",
...@@ -553,6 +554,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
label_gain = Common::StringToArray<double>(tmp_str, ',');
}
GetInt(params, "objective_seed", &objective_seed);
GetInt(params, "metric_freq", &metric_freq);
CHECK(metric_freq >0);
...@@ -688,6 +691,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[max_position: " << max_position << "]\n";
str_buf << "[lambdamart_norm: " << lambdamart_norm << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
......
...@@ -7,6 +7,7 @@
#include "binary_objective.hpp"
#include "multiclass_objective.hpp"
#include "rank_objective.hpp"
#include "rank_xendcg_objective.hpp"
#include "regression_objective.hpp"
#include "xentropy_objective.hpp"
...@@ -29,6 +30,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new BinaryLogloss(config);
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(config);
} else if (type == std::string("rank_xendcg")) {
return new RankXENDCG(config);
} else if (type == std::string("multiclass")) {
return new MulticlassSoftmax(config);
} else if (type == std::string("multiclassova")) {
...@@ -68,6 +71,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new BinaryLogloss(strs);
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(strs);
} else if (type == std::string("rank_xendcg")) {
return new RankXENDCG(strs);
} else if (type == std::string("multiclass")) {
return new MulticlassSoftmax(strs);
} else if (type == std::string("multiclassova")) {
......
/*!
* Copyright (c) 2019 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_OBJECTIVE_RANK_XENDCG_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_RANK_XENDCG_OBJECTIVE_HPP_
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <sstream>
#include <string>
#include <vector>
namespace LightGBM {
/*!
* \brief Implementation of the learning-to-rank objective function, XE_NDCG [arxiv.org/abs/1911.09798].
*/
class RankXENDCG: public ObjectiveFunction {
public:
explicit RankXENDCG(const Config& config) {
rand_ = new Random(config.objective_seed);
}
explicit RankXENDCG(const std::vector<std::string>&) {
rand_ = new Random();
}
~RankXENDCG() {
delete rand_;
}
void Init(const Metadata& metadata, data_size_t) override {
// get label
label_ = metadata.label();
// get query boundaries
query_boundaries_ = metadata.query_boundaries();
if (query_boundaries_ == nullptr) {
Log::Fatal("RankXENDCG tasks require query information");
}
num_queries_ = metadata.num_queries();
}
void GetGradients(const double* score, score_t* gradients,
score_t* hessians) const override {
// Note: rand_ is shared across threads, so results are only
// reproducible when num_threads is set to 1.
#pragma omp parallel for schedule(guided)
for (data_size_t i = 0; i < num_queries_; ++i) {
GetGradientsForOneQuery(score, gradients, hessians, i);
}
}
inline void GetGradientsForOneQuery(
const double* score,
score_t* lambdas, score_t* hessians, data_size_t query_id) const {
// get doc boundary for current query
const data_size_t start = query_boundaries_[query_id];
const data_size_t cnt =
query_boundaries_[query_id + 1] - query_boundaries_[query_id];
// add pointers with offset
const label_t* label = label_ + start;
score += start;
lambdas += start;
hessians += start;
// Turn scores into a probability distribution using Softmax.
std::vector<double> rho(cnt);
Common::Softmax(score, &rho[0], cnt);
// Prepare a vector of gammas, a parameter of the loss.
std::vector<double> gammas(cnt);
for (data_size_t i = 0; i < cnt; ++i) {
gammas[i] = rand_->NextFloat();
}
// Skip query if sum of labels is 0.
float sum_labels = 0;
for (data_size_t i = 0; i < cnt; ++i) {
sum_labels += phi(label[i], gammas[i]);
}
if (sum_labels == 0) {
return;
}
// Approximate gradients and inverse Hessian.
// First order terms.
std::vector<double> L1s(cnt);
for (data_size_t i = 0; i < cnt; ++i) {
L1s[i] = -phi(label[i], gammas[i])/sum_labels + rho[i];
}
// Second-order terms.
std::vector<double> L2s(cnt);
for (data_size_t i = 0; i < cnt; ++i) {
for (data_size_t j = 0; j < cnt; ++j) {
if (i == j) continue;
L2s[i] += L1s[j] / (1 - rho[j]);
}
}
// Third-order terms.
std::vector<double> L3s(cnt);
for (data_size_t i = 0; i < cnt; ++i) {
for (data_size_t j = 0; j < cnt; ++j) {
if (i == j) continue;
L3s[i] += rho[j] * L2s[j] / (1 - rho[j]);
}
}
// Finally, prepare lambdas and hessians.
for (data_size_t i = 0; i < cnt; ++i) {
lambdas[i] = static_cast<score_t>(
L1s[i] + rho[i]*L2s[i] + rho[i]*L3s[i]);
hessians[i] = static_cast<score_t>(rho[i] * (1.0 - rho[i]));
}
}
double phi(const label_t l, double g) const {
return Common::Pow(2, l) - g;
}
const char* GetName() const override {
return "rank_xendcg";
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName();
return str_buf.str();
}
bool NeedAccuratePrediction() const override { return false; }
private:
/*! \brief Number of queries */
data_size_t num_queries_;
/*! \brief Pointer of label */
const label_t* label_;
/*! \brief Query boundaries */
const data_size_t* query_boundaries_;
/*! \brief Pseudo-random number generator */
Random* rand_;
};
} // namespace LightGBM
#endif  // LIGHTGBM_OBJECTIVE_RANK_XENDCG_OBJECTIVE_HPP_
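The per-query gradient computation in `GetGradientsForOneQuery` above can be sketched in NumPy (an illustrative translation, not the LightGBM API; the function name and signature are invented for this sketch):

```python
import numpy as np

def xendcg_gradients(scores, labels, rng):
    """Gradients/hessians of the XE_NDCG loss for one query,
    mirroring RankXENDCG::GetGradientsForOneQuery above."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Turn scores into a probability distribution using Softmax.
    e = np.exp(scores - scores.max())
    rho = e / e.sum()
    # Per-document gamma draws, a parameter of the loss.
    gammas = rng.random(len(scores))
    # phi(l, gamma) = 2^l - gamma
    phi = np.power(2.0, labels) - gammas
    if phi.sum() == 0:  # skip query if the sum of labels is 0
        return np.zeros_like(rho), np.zeros_like(rho)
    # First-order terms.
    l1 = -phi / phi.sum() + rho
    # Second-order terms: L2[i] = sum_{j != i} L1[j] / (1 - rho[j])
    t = l1 / (1.0 - rho)
    l2 = t.sum() - t
    # Third-order terms: L3[i] = sum_{j != i} rho[j] * L2[j] / (1 - rho[j])
    t = rho * l2 / (1.0 - rho)
    l3 = t.sum() - t
    # Finally, lambdas and hessians.
    lambdas = l1 + rho * (l2 + l3)
    hessians = rho * (1.0 - rho)
    return lambdas, hessians
```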
...@@ -110,3 +110,15 @@ class TestEngine(unittest.TestCase):
sk_pred = gbm.predict(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_xendcg(self):
fd = FileLoader('../../examples/xendcg', 'rank')
X_train, y_train, _ = fd.load_dataset('.train', is_sparse=True)
X_test, _, X_test_fn = fd.load_dataset('.test', is_sparse=True)
group_train = fd.load_field('.train.query')
lgb_train = lgb.Dataset(X_train, y_train, group=group_train)
gbm = lgb.LGBMRanker(**fd.params)
gbm.fit(X_train, y_train, group=group_train)
sk_pred = gbm.predict(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
...@@ -117,6 +117,21 @@ class TestSklearn(unittest.TestCase):
self.assertGreater(gbm.best_score_['valid_0']['ndcg@1'], 0.6333)
self.assertGreater(gbm.best_score_['valid_0']['ndcg@3'], 0.6048)
def test_xendcg(self):
dir_path = os.path.dirname(os.path.realpath(__file__))
X_train, y_train = load_svmlight_file(os.path.join(dir_path, '../../examples/xendcg/rank.train'))
X_test, y_test = load_svmlight_file(os.path.join(dir_path, '../../examples/xendcg/rank.test'))
q_train = np.loadtxt(os.path.join(dir_path, '../../examples/xendcg/rank.train.query'))
q_test = np.loadtxt(os.path.join(dir_path, '../../examples/xendcg/rank.test.query'))
gbm = lgb.LGBMRanker(n_estimators=50, objective='rank_xendcg', random_state=5, n_jobs=1)
gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)],
eval_group=[q_test], eval_at=[1, 3], early_stopping_rounds=10, verbose=False,
eval_metric='ndcg',
callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))])
self.assertLessEqual(gbm.best_iteration_, 24)
self.assertGreater(gbm.best_score_['valid_0']['ndcg@1'], 0.6579)
self.assertGreater(gbm.best_score_['valid_0']['ndcg@3'], 0.6421)
def test_regression_with_custom_objective(self):
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
...@@ -237,6 +237,7 @@
<ClInclude Include="..\src\network\socket_wrapper.hpp" />
<ClInclude Include="..\src\objective\binary_objective.hpp" />
<ClInclude Include="..\src\objective\rank_objective.hpp" />
<ClInclude Include="..\src\objective\rank_xendcg_objective.hpp" />
<ClInclude Include="..\src\objective\regression_objective.hpp" />
<ClInclude Include="..\src\objective\multiclass_objective.hpp" />
<ClInclude Include="..\src\objective\xentropy_objective.hpp" />
...@@ -284,4 +285,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>
\ No newline at end of file
...@@ -87,6 +87,9 @@
<ClInclude Include="..\src\objective\rank_objective.hpp">
<Filter>src\objective</Filter>
</ClInclude>
<ClInclude Include="..\src\objective\rank_xendcg_objective.hpp">
<Filter>src\objective</Filter>
</ClInclude>
<ClInclude Include="..\src\objective\regression_objective.hpp">
<Filter>src\objective</Filter>
</ClInclude>
...@@ -306,4 +309,4 @@
<Filter>src\io</Filter>
</ClCompile>
</ItemGroup>
</Project>
\ No newline at end of file