Commit f83e4aa4 authored by Nikita Titov, committed by Guolin Ke
Browse files

[test] added regression test for subsetting with group (#1654)

parent 5457ef6b
...@@ -6,7 +6,7 @@ import unittest ...@@ -6,7 +6,7 @@ import unittest
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
from sklearn.datasets import load_breast_cancer, dump_svmlight_file from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
...@@ -78,3 +78,14 @@ class TestBasic(unittest.TestCase): ...@@ -78,3 +78,14 @@ class TestBasic(unittest.TestCase):
train_data.construct() train_data.construct()
valid_data.construct() valid_data.construct()
def test_subset_group(self):
    """Regression test for taking a subset of a Dataset that carries groups.

    Loads the lambdarank example training data together with its query
    file, checks that the full Dataset reports 201 groups, then takes the
    first ten rows as a subset and checks that the constructed subset
    reports exactly two groups of sizes 1 and 9.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    features, labels = load_svmlight_file(os.path.join(here, '../../examples/lambdarank/rank.train'))
    query_sizes = np.loadtxt(os.path.join(here, '../../examples/lambdarank/rank.train.query'))

    full_data = lgb.Dataset(features, labels, group=query_sizes)
    self.assertEqual(len(full_data.get_group()), 201)

    # Row indices 0..9; range_ comes from lgb.compat, as in the original test.
    first_ten_rows = list(lgb.compat.range_(10))
    ten_row_subset = full_data.subset(first_ten_rows).construct()

    groups_in_subset = ten_row_subset.get_group()
    self.assertEqual(len(groups_in_subset), 2)
    self.assertEqual(groups_in_subset[0], 1)
    self.assertEqual(groups_in_subset[1], 9)
...@@ -424,9 +424,16 @@ class TestEngine(unittest.TestCase): ...@@ -424,9 +424,16 @@ class TestEngine(unittest.TestCase):
# lambdarank # lambdarank
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train')) X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train'))
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train.query')) q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train.query'))
params_lambdarank = {'objective': 'lambdarank', 'verbose': -1} params_lambdarank = {'objective': 'lambdarank', 'verbose': -1, 'eval_at': 3}
lgb_train = lgb.Dataset(X_train, y_train, group=q_train) lgb_train = lgb.Dataset(X_train, y_train, group=q_train)
lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, stratified=False, metrics='l2', verbose_eval=False) # ... with NDCG (default) metric
cv_res = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, stratified=False, verbose_eval=False)
self.assertEqual(len(cv_res), 2)
self.assertFalse(np.isnan(cv_res['ndcg@3-mean']).any())
# ... with l2 metric
cv_res = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, stratified=False, metrics='l2', verbose_eval=False)
self.assertEqual(len(cv_res), 2)
self.assertFalse(np.isnan(cv_res['l2-mean']).any())
def test_feature_name(self): def test_feature_name(self):
X, y = load_boston(True) X, y = load_boston(True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment