"vscode:/vscode.git/clone" did not exist on "6f19edde528904ee3d61bdfa6fcf1e012a3ae2ed"
Unverified Commit a4478f7e authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] Fix training on subset constructed without params (#5213)

* Update basic.py

* Update test_engine.py

* Add return type annotation
parent 17bfe1a1
...@@ -1348,12 +1348,12 @@ class Dataset: ...@@ -1348,12 +1348,12 @@ class Dataset:
self._start_row += nrow self._start_row += nrow
return self return self
def get_params(self): def get_params(self) -> Dict[str, Any]:
"""Get the used parameters in the Dataset. """Get the used parameters in the Dataset.
Returns Returns
------- -------
params : dict or None params : dict
The used parameters in this Dataset object. The used parameters in this Dataset object.
""" """
if self.params is not None: if self.params is not None:
...@@ -1380,6 +1380,8 @@ class Dataset: ...@@ -1380,6 +1380,8 @@ class Dataset:
"weight_column", "weight_column",
"zero_as_missing") "zero_as_missing")
return {k: v for k, v in self.params.items() if k in dataset_params} return {k: v for k, v in self.params.items() if k in dataset_params}
else:
return {}
def _free_handle(self): def _free_handle(self):
if self.handle is not None: if self.handle is not None:
......
...@@ -1481,6 +1481,18 @@ def test_init_with_subset(): ...@@ -1481,6 +1481,18 @@ def test_init_with_subset():
assert subset_data_4.get_data() == "lgb_train_data.bin" assert subset_data_4.get_data() == "lgb_train_data.bin"
def test_training_on_constructed_subset_without_params():
    """Check that training works on a constructed subset created without params.

    The parent Dataset is built with no explicit parameters, a four-row subset
    is taken from it and constructed up front, and a single boosting round is
    run on that subset with an empty parameter dict.
    """
    subset_indices = [1, 2, 3, 4]
    features = np.random.random((100, 10))
    target = np.random.random(100)
    full_dataset = lgb.Dataset(features, target)
    subset = full_dataset.subset(subset_indices).construct()
    booster = lgb.train({}, subset, num_boost_round=1)
    # With no params supplied anywhere, the subset reports an empty dict.
    assert subset.get_params() == {}
    # The subset holds exactly the selected rows.
    assert subset.num_data() == len(subset_indices)
    # Exactly one boosting round was performed.
    assert booster.current_iteration() == 1
def generate_trainset_for_monotone_constraints_tests(x3_to_category=True): def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
number_of_dpoints = 3000 number_of_dpoints = 3000
x1_positively_correlated_with_y = np.random.random(size=number_of_dpoints) x1_positively_correlated_with_y = np.random.random(size=number_of_dpoints)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment