Unverified Commit 9f6e4413 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] keep consistent state for Dataset fields (#2390)

* keep consistent state for Dataset fields

* hotfix
parent f52be9be
......@@ -1339,6 +1339,7 @@ class Dataset(object):
if self.handle is not None:
label = list_to_1d_numpy(_label_from_pandas(label), name='label')
self.set_field('label', label)
self.label = self.get_field('label') # original values can be modified at cpp side
return self
def set_weight(self, weight):
......@@ -1360,6 +1361,7 @@ class Dataset(object):
if self.handle is not None and weight is not None:
weight = list_to_1d_numpy(weight, name='weight')
self.set_field('weight', weight)
self.weight = self.get_field('weight') # original values can be modified at cpp side
return self
def set_init_score(self, init_score):
......@@ -1379,6 +1381,7 @@ class Dataset(object):
if self.handle is not None and init_score is not None:
init_score = list_to_1d_numpy(init_score, np.float64, name='init_score')
self.set_field('init_score', init_score)
self.init_score = self.get_field('init_score') # original values can be modified at cpp side
return self
def set_group(self, group):
......
......@@ -281,3 +281,34 @@ class TestBasic(unittest.TestCase):
with open(p2name, 'rt') as f:
p2txt = f.read()
self.assertEqual(p1txt, p2txt)
def test_consistent_state_for_dataset_fields(self):
def check_asserts(data):
np.testing.assert_allclose(data.label, data.get_label())
np.testing.assert_allclose(data.label, data.get_field('label'))
self.assertFalse(np.isnan(data.label[0]))
self.assertFalse(np.isinf(data.label[1]))
np.testing.assert_allclose(data.weight, data.get_weight())
np.testing.assert_allclose(data.weight, data.get_field('weight'))
self.assertFalse(np.isnan(data.weight[0]))
self.assertFalse(np.isinf(data.weight[1]))
np.testing.assert_allclose(data.init_score, data.get_init_score())
np.testing.assert_allclose(data.init_score, data.get_field('init_score'))
self.assertFalse(np.isnan(data.init_score[0]))
self.assertFalse(np.isinf(data.init_score[1]))
self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]],
data.label[0])))
self.assertAlmostEqual(data.label[1], data.weight[1])
X, y = load_breast_cancer(True)
sequence = np.ones(y.shape[0])
sequence[0] = np.nan
sequence[1] = np.inf
lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct()
check_asserts(lgb_data)
lgb_data = lgb.Dataset(X, y).construct()
lgb_data.set_label(sequence)
lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence)
check_asserts(lgb_data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment