"tests/vscode:/vscode.git/clone" did not exist on "6f0bc4813292873abc9317e4ee8b7413bae0e1cd"
Commit c04c0917 authored by Nikita Titov's avatar Nikita Titov Committed by Tsukasa OMOTO
Browse files

[python] minor fixes in sklearn wrapper (#1572)

* refined random_state description

* use neutral alias and fixed desc of eval_at param, since metric can be not only ndcg

* simplified checks

* consider verbose alias

* fixed declaring function every loop iteration
parent d93fbba3
......@@ -188,7 +188,7 @@ class LGBMModel(_LGBMModelBase):
L2 regularization term on weights.
random_state : int or None, optional (default=None)
Random number seed.
Will use default seeds in c++ code if set to None.
If None, default seeds in C++ code will be used.
n_jobs : int, optional (default=-1)
Number of parallel threads.
silent : bool, optional (default=True)
......@@ -398,7 +398,7 @@ class LGBMModel(_LGBMModelBase):
evals_result = {}
params = self.get_params()
# user can set verbose with kwargs, it has higher priority
if 'verbose' not in params and self.silent:
if not any(verbose_alias in params for verbose_alias in ('verbose', 'verbosity')) and self.silent:
params['verbose'] = 0
params.pop('silent', None)
params.pop('importance_type', None)
......@@ -407,7 +407,7 @@ class LGBMModel(_LGBMModelBase):
if self._n_classes is not None and self._n_classes > 2:
params['num_class'] = self._n_classes
if hasattr(self, '_eval_at'):
params['ndcg_eval_at'] = self._eval_at
params['eval_at'] = self._eval_at
params['objective'] = self._objective
if self._fobj:
params['objective'] = 'None' # objective = nullptr for unknown objective
......@@ -440,14 +440,8 @@ class LGBMModel(_LGBMModelBase):
valid_sets = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
def get_meta_data(collection, i):
def _get_meta_data(collection, i):
if collection is None:
return None
elif isinstance(collection, list):
......@@ -456,15 +450,23 @@ class LGBMModel(_LGBMModelBase):
return collection.get(i, None)
else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
valid_weight = get_meta_data(eval_sample_weight, i)
if get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _get_meta_data(eval_sample_weight, i)
if _get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = get_meta_data(eval_init_score, i)
valid_group = get_meta_data(eval_group, i)
valid_init_score = _get_meta_data(eval_init_score, i)
valid_group = _get_meta_data(eval_group, i)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set)
......@@ -668,14 +670,14 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
ova_aliases = ("multiclassova", "multiclass_ova", "ova", "ovr")
if self._objective not in ova_aliases and not callable(self._objective):
self._objective = "multiclass"
if eval_metric == 'logloss' or eval_metric == 'binary_logloss':
if eval_metric in ('logloss', 'binary_logloss'):
eval_metric = "multi_logloss"
elif eval_metric == 'error' or eval_metric == 'binary_error':
elif eval_metric in ('error', 'binary_error'):
eval_metric = "multi_error"
else:
if eval_metric == 'logloss' or eval_metric == 'multi_logloss':
if eval_metric in ('logloss', 'multi_logloss'):
eval_metric = 'binary_logloss'
elif eval_metric == 'error' or eval_metric == 'multi_error':
elif eval_metric in ('error', 'multi_error'):
eval_metric = 'binary_error'
if eval_set is not None:
......@@ -809,5 +811,5 @@ class LGBMRanker(LGBMModel):
+ 'eval_metric : string, list of strings, callable or None, optional (default="ndcg")\n'
+ _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):_base_doc.find('early_stopping_rounds :')]
+ 'eval_at : list of int, optional (default=[1])\n'
' The evaluation positions of NDCG.\n'
' The evaluation positions of the specified metric.\n'
+ _base_doc[_base_doc.find(' early_stopping_rounds :'):])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment