Commit 616388e0 authored by Guolin Ke's avatar Guolin Ke
Browse files

[python-package] pass params of engine.train and engine.cv to Dataset.

parent 292f972e
...@@ -927,6 +927,12 @@ class Dataset(object): ...@@ -927,6 +927,12 @@ class Dataset(object):
ret._set_predictor(self._predictor) ret._set_predictor(self._predictor)
return ret return ret
def _update_params(self, params):
if not self.params:
self.params = params
else:
self.params.update(params)
def construct(self): def construct(self):
""" """
Lazy init Lazy init
......
...@@ -92,6 +92,7 @@ def train(params, train_set, num_boost_round=100, ...@@ -92,6 +92,7 @@ def train(params, train_set, num_boost_round=100,
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError("Traninig only accepts Dataset object") raise TypeError("Traninig only accepts Dataset object")
train_set._update_params(params)
train_set._set_predictor(predictor) train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name) train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature) train_set.set_categorical_feature(categorical_feature)
...@@ -120,7 +121,8 @@ def train(params, train_set, num_boost_round=100, ...@@ -120,7 +121,8 @@ def train(params, train_set, num_boost_round=100,
name_valid_sets.append(valid_names[i]) name_valid_sets.append(valid_names[i])
else: else:
name_valid_sets.append('valid_'+str(i)) name_valid_sets.append('valid_'+str(i))
for valid_data in valid_sets:
valid_data._update_params(params)
"""process callbacks""" """process callbacks"""
if callbacks is None: if callbacks is None:
callbacks = set() callbacks = set()
...@@ -332,7 +334,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, ...@@ -332,7 +334,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
predictor = init_model._to_predictor() predictor = init_model._to_predictor()
else: else:
predictor = None predictor = None
train_set._update_params(params)
train_set._set_predictor(predictor) train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name) train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature) train_set.set_categorical_feature(categorical_feature)
......
...@@ -273,11 +273,11 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -273,11 +273,11 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0); CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
GetDouble(params, "lambda_l1", &lambda_l1); GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >= 0.0f) CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2); GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >= 0.0f) CHECK(lambda_l2 >= 0.0f);
GetDouble(params, "min_gain_to_split", &min_gain_to_split); GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >= 0.0f) CHECK(min_gain_to_split >= 0.0f);
GetInt(params, "num_leaves", &num_leaves); GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves > 1); CHECK(num_leaves > 1);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment