Commit c04c0917 authored by Nikita Titov's avatar Nikita Titov Committed by Tsukasa OMOTO
Browse files

[python] minor fixes in sklearn wrapper (#1572)

* refined random_state description

* use neutral alias and fixed desc of eval_at param, since metric can be not only ndcg

* simplified checks

* consider verbose alias

* fixed declaring function every loop iteration
parent d93fbba3
...@@ -188,7 +188,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -188,7 +188,7 @@ class LGBMModel(_LGBMModelBase):
L2 regularization term on weights. L2 regularization term on weights.
random_state : int or None, optional (default=None) random_state : int or None, optional (default=None)
Random number seed. Random number seed.
Will use default seeds in c++ code if set to None. If None, default seeds in C++ code will be used.
n_jobs : int, optional (default=-1) n_jobs : int, optional (default=-1)
Number of parallel threads. Number of parallel threads.
silent : bool, optional (default=True) silent : bool, optional (default=True)
...@@ -398,7 +398,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -398,7 +398,7 @@ class LGBMModel(_LGBMModelBase):
evals_result = {} evals_result = {}
params = self.get_params() params = self.get_params()
# user can set verbose with kwargs, it has higher priority # user can set verbose with kwargs, it has higher priority
if 'verbose' not in params and self.silent: if not any(verbose_alias in params for verbose_alias in ('verbose', 'verbosity')) and self.silent:
params['verbose'] = 0 params['verbose'] = 0
params.pop('silent', None) params.pop('silent', None)
params.pop('importance_type', None) params.pop('importance_type', None)
...@@ -407,7 +407,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -407,7 +407,7 @@ class LGBMModel(_LGBMModelBase):
if self._n_classes is not None and self._n_classes > 2: if self._n_classes is not None and self._n_classes > 2:
params['num_class'] = self._n_classes params['num_class'] = self._n_classes
if hasattr(self, '_eval_at'): if hasattr(self, '_eval_at'):
params['ndcg_eval_at'] = self._eval_at params['eval_at'] = self._eval_at
params['objective'] = self._objective params['objective'] = self._objective
if self._fobj: if self._fobj:
params['objective'] = 'None' # objective = nullptr for unknown objective params['objective'] = 'None' # objective = nullptr for unknown objective
...@@ -440,6 +440,17 @@ class LGBMModel(_LGBMModelBase): ...@@ -440,6 +440,17 @@ class LGBMModel(_LGBMModelBase):
valid_sets = [] valid_sets = []
if eval_set is not None: if eval_set is not None:
def _get_meta_data(collection, i):
if collection is None:
return None
elif isinstance(collection, list):
return collection[i] if len(collection) > i else None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
if isinstance(eval_set, tuple): if isinstance(eval_set, tuple):
eval_set = [eval_set] eval_set = [eval_set]
for i, valid_data in enumerate(eval_set): for i, valid_data in enumerate(eval_set):
...@@ -447,24 +458,15 @@ class LGBMModel(_LGBMModelBase): ...@@ -447,24 +458,15 @@ class LGBMModel(_LGBMModelBase):
if valid_data[0] is X and valid_data[1] is y: if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set valid_set = train_set
else: else:
def get_meta_data(collection, i): valid_weight = _get_meta_data(eval_sample_weight, i)
if collection is None: if _get_meta_data(eval_class_weight, i) is not None:
return None valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i), valid_data[1])
elif isinstance(collection, list):
return collection[i] if len(collection) > i else None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
valid_weight = get_meta_data(eval_sample_weight, i)
if get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0: if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight valid_weight = valid_class_sample_weight
else: else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight) valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = get_meta_data(eval_init_score, i) valid_init_score = _get_meta_data(eval_init_score, i)
valid_group = get_meta_data(eval_group, i) valid_group = _get_meta_data(eval_group, i)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params) valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set) valid_sets.append(valid_set)
...@@ -668,14 +670,14 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -668,14 +670,14 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
ova_aliases = ("multiclassova", "multiclass_ova", "ova", "ovr") ova_aliases = ("multiclassova", "multiclass_ova", "ova", "ovr")
if self._objective not in ova_aliases and not callable(self._objective): if self._objective not in ova_aliases and not callable(self._objective):
self._objective = "multiclass" self._objective = "multiclass"
if eval_metric == 'logloss' or eval_metric == 'binary_logloss': if eval_metric in ('logloss', 'binary_logloss'):
eval_metric = "multi_logloss" eval_metric = "multi_logloss"
elif eval_metric == 'error' or eval_metric == 'binary_error': elif eval_metric in ('error', 'binary_error'):
eval_metric = "multi_error" eval_metric = "multi_error"
else: else:
if eval_metric == 'logloss' or eval_metric == 'multi_logloss': if eval_metric in ('logloss', 'multi_logloss'):
eval_metric = 'binary_logloss' eval_metric = 'binary_logloss'
elif eval_metric == 'error' or eval_metric == 'multi_error': elif eval_metric in ('error', 'multi_error'):
eval_metric = 'binary_error' eval_metric = 'binary_error'
if eval_set is not None: if eval_set is not None:
...@@ -809,5 +811,5 @@ class LGBMRanker(LGBMModel): ...@@ -809,5 +811,5 @@ class LGBMRanker(LGBMModel):
+ 'eval_metric : string, list of strings, callable or None, optional (default="ndcg")\n' + 'eval_metric : string, list of strings, callable or None, optional (default="ndcg")\n'
+ _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):_base_doc.find('early_stopping_rounds :')] + _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):_base_doc.find('early_stopping_rounds :')]
+ 'eval_at : list of int, optional (default=[1])\n' + 'eval_at : list of int, optional (default=[1])\n'
' The evaluation positions of NDCG.\n' ' The evaluation positions of the specified metric.\n'
+ _base_doc[_base_doc.find(' early_stopping_rounds :'):]) + _base_doc[_base_doc.find(' early_stopping_rounds :'):])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment