Commit 616388e0 authored by Guolin Ke's avatar Guolin Ke
Browse files

[python-package] pass params of engine.train and engine.cv to Dataset.

parent 292f972e
......@@ -927,6 +927,12 @@ class Dataset(object):
ret._set_predictor(self._predictor)
return ret
def _update_params(self, params):
if not self.params:
self.params = params
else:
self.params.update(params)
def construct(self):
"""
Lazy init
......
......@@ -92,6 +92,7 @@ def train(params, train_set, num_boost_round=100,
if not isinstance(train_set, Dataset):
raise TypeError("Traninig only accepts Dataset object")
train_set._update_params(params)
train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature)
......@@ -120,7 +121,8 @@ def train(params, train_set, num_boost_round=100,
name_valid_sets.append(valid_names[i])
else:
name_valid_sets.append('valid_'+str(i))
for valid_data in valid_sets:
valid_data._update_params(params)
"""process callbacks"""
if callbacks is None:
callbacks = set()
......@@ -332,7 +334,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
predictor = init_model._to_predictor()
else:
predictor = None
train_set._update_params(params)
train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature)
......
......@@ -273,12 +273,12 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >= 0.0f)
GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >= 0.0f)
GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >= 0.0f)
GetInt(params, "num_leaves", &num_leaves);
CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >= 0.0f);
GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >= 0.0f);
GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves > 1);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
GetDouble(params, "feature_fraction", &feature_fraction);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment