Commit 7825084f authored by Nikita Titov, committed by Guolin Ke

[python] break extremely large lines and fix classname (#1698)

* break extremely large lines in basic.py

* break extremely large lines in callback.py

* break extremely large lines in engine.py

* break extremely large lines in sklearn.py

* hotfixes
parent a760eae4
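
Note: most hunks below apply the same two PEP 8 techniques for shortening long lines: adjacent string literals inside one set of parentheses are concatenated at compile time, and a parenthesized expression may continue across lines without backslashes. A minimal sketch of both (the snippets mirror strings from the patch itself):

import warnings

# Implicit concatenation: the two literals form a single message string.
warnings.warn("Usage of np.ndarray subset (sliced data) is not recommended "
              "due to it will double the peak memory cost in LightGBM.")

# Parenthesized continuation: one expression split across lines.
msg = ("DataFrame.dtypes for data must be int, float or bool.\n"
       "Did not expect the data types in fields ")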
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -134,7 +134,7 @@ def param_dict_to_str(data):
     return ' '.join(pairs)
 
 
-class _temp_file(object):
+class _TempFile(object):
     def __enter__(self):
         with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
             self.name = f.name
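
The only change here is the "fix classname" part of the commit: PEP 8 reserves CapWords for classes, so the private context manager _temp_file becomes _TempFile. A rough sketch of how such a context manager works; the __exit__ cleanup shown is an assumption for illustration, not part of this hunk:

import os
from tempfile import NamedTemporaryFile

class _TempFile(object):
    def __enter__(self):
        # Reserve a unique path; the file itself is deleted on close,
        # leaving only the name free for later use.
        with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
            self.name = f.name
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Remove the file if a callee recreated it at self.name.
        if os.path.isfile(self.name):
            os.remove(self.name)

with _TempFile() as f:
    print(f.name)  # a usable temporary path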
@@ -192,7 +192,8 @@ def convert_from_sliced_object(data):
     """fix the memory of multi-dimensional sliced object"""
     if data.base is not None and isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
         if not data.flags.c_contiguous:
-            warnings.warn("Usage of np.ndarray subset (sliced data) is not recommended due to it will double the peak memory cost in LightGBM.")
+            warnings.warn("Usage of np.ndarray subset (sliced data) is not recommended "
+                          "due to it will double the peak memory cost in LightGBM.")
         return np.copy(data)
     return data
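
Why this warning exists: slicing a 2-D ndarray yields a view whose memory may not be C-contiguous, so it must be copied before being handed to the C API, briefly doubling peak memory. A small self-contained demo of the condition being checked:

import numpy as np

mat = np.arange(12, dtype=np.float64).reshape(3, 4)
col_slice = mat[:, 1:3]                # a view into mat, not a copy
print(col_slice.base is mat)           # True  -> sliced object
print(col_slice.flags.c_contiguous)    # False -> triggers the warning
fixed = np.copy(col_slice)             # contiguous copy, safe for C code
print(fixed.flags.c_contiguous)        # True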
@@ -271,7 +272,8 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
             bad_fields = [data.columns[i] for i, dtype in
                           enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
 
-            msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
+            msg = ("DataFrame.dtypes for data must be int, float or bool.\n"
+                   "Did not expect the data types in fields ")
             raise ValueError(msg + ', '.join(bad_fields))
         data = data.values.astype('float')
     else:
@@ -295,7 +297,8 @@ def _label_from_pandas(label):
 
 def _save_pandas_categorical(file_name, pandas_categorical):
     with open(file_name, 'a') as f:
-        f.write('\npandas_categorical:' + json.dumps(pandas_categorical, default=json_default_with_numpy) + '\n')
+        f.write('\npandas_categorical:'
+                + json.dumps(pandas_categorical, default=json_default_with_numpy) + '\n')
 
 
 def _load_pandas_categorical(file_name):
@@ -418,7 +421,7 @@ class _InnerPredictor(object):
             num_iteration = self.num_total_iteration
         if isinstance(data, string_type):
-            with _temp_file() as f:
+            with _TempFile() as f:
                 _safe_call(_LIB.LGBM_BoosterPredictForFile(
                     self.handle,
                     c_str(data),
@@ -521,7 +524,8 @@ class _InnerPredictor(object):
             n_preds = [self.__get_num_preds(num_iteration, i, predict_type) for i in np.diff([0] + list(sections) + [nrow])]
             n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
             preds = np.zeros(sum(n_preds), dtype=np.float64)
-            for chunk, (start_idx_pred, end_idx_pred) in zip_(np.array_split(mat, sections), zip_(n_preds_sections, n_preds_sections[1:])):
+            for chunk, (start_idx_pred, end_idx_pred) in zip_(np.array_split(mat, sections),
+                                                              zip_(n_preds_sections, n_preds_sections[1:])):
                 # avoid memory consumption by arrays concatenation operations
                 inner_predict(chunk, num_iteration, predict_type, preds[start_idx_pred:end_idx_pred])
             return preds, nrow
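
The reflowed loop belongs to a chunking pattern: the input matrix is split into row sections, per-section output sizes are accumulated with cumsum, and each chunk's result is written into a preallocated array instead of concatenating intermediates. A sketch of the bookkeeping with a stand-in for inner_predict:

import numpy as np

nrow = 10
mat = np.random.rand(nrow, 3)
sections = [4, 7]                                  # row indices where splits occur
n_preds = list(np.diff([0] + sections + [nrow]))   # rows per chunk: [4, 3, 3]
offsets = np.array([0] + n_preds, dtype=np.intp).cumsum()
preds = np.zeros(sum(n_preds), dtype=np.float64)
for chunk, (start, end) in zip(np.array_split(mat, sections),
                               zip(offsets, offsets[1:])):
    preds[start:end] = chunk.sum(axis=1)           # stand-in for inner_predict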
@@ -692,16 +696,22 @@ class Dataset(object):
         if reference is not None:
             self.pandas_categorical = reference.pandas_categorical
             categorical_feature = reference.categorical_feature
-        data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(data, feature_name, categorical_feature, self.pandas_categorical)
+        data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(data,
+                                                                                             feature_name,
+                                                                                             categorical_feature,
+                                                                                             self.pandas_categorical)
         label = _label_from_pandas(label)
         self.data_has_header = False
         # process for args
         params = {} if params is None else params
-        args_names = getattr(self.__class__, '_lazy_init').__code__.co_varnames[:getattr(self.__class__, '_lazy_init').__code__.co_argcount]
+        args_names = (getattr(self.__class__, '_lazy_init')
+                      .__code__
+                      .co_varnames[:getattr(self.__class__, '_lazy_init').__code__.co_argcount])
         for key, _ in params.items():
             if key in args_names:
-                warnings.warn('{0} keyword has been found in `params` and will be ignored. '
-                              'Please use {0} argument of the Dataset constructor to pass this parameter.'.format(key))
+                warnings.warn('{0} keyword has been found in `params` and will be ignored.\n'
+                              'Please use {0} argument of the Dataset constructor to pass this parameter.'
+                              .format(key))
         self.predictor = predictor
         # user can set verbose with params, it has higher priority
         if not any(verbose_alias in params for verbose_alias in ('verbose', 'verbosity')) and silent:
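
The args_names expression relies on function introspection: __code__.co_varnames lists a function's local names with the declared parameters first, so slicing by co_argcount yields exactly the named arguments. A minimal demonstration:

def _lazy_init(self, data, label=None, reference=None, params=None):
    pass

arg_count = _lazy_init.__code__.co_argcount
print(_lazy_init.__code__.co_varnames[:arg_count])
# ('self', 'data', 'label', 'reference', 'params')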
@@ -930,7 +940,8 @@ class Dataset(object):
                 if self.used_indices is None:
                     # create valid
                     self._lazy_init(self.data, label=self.label, reference=self.reference,
-                                    weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor,
+                                    weight=self.weight, group=self.group,
+                                    init_score=self.init_score, predictor=self._predictor,
                                     silent=self.silent, feature_name=self.feature_name, params=self.params)
                 else:
                     # construct subset
@@ -938,7 +949,8 @@ class Dataset(object):
                     assert used_indices.flags.c_contiguous
                     if self.reference.group is not None:
                         group_info = np.array(self.reference.group).astype(int)
-                        _, self.group = np.unique(np.repeat(range_(len(group_info)), repeats=group_info)[self.used_indices], return_counts=True)
+                        _, self.group = np.unique(np.repeat(range_(len(group_info)), repeats=group_info)[self.used_indices],
+                                                  return_counts=True)
                     self.handle = ctypes.c_void_p()
                     params_str = param_dict_to_str(self.params)
                     _safe_call(_LIB.LGBM_DatasetGetSubset(
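
The np.unique/np.repeat idiom here rebuilds per-query group sizes for a subset: repeat expands the group sizes into one group id per row, indexing keeps only the subset rows, and unique(..., return_counts=True) counts the rows left in each surviving group. A small worked example:

import numpy as np

group_info = np.array([2, 3, 1])                   # sizes of 3 query groups
row_group = np.repeat(range(len(group_info)), repeats=group_info)
print(row_group)                                   # [0 0 1 1 1 2]
used_indices = np.array([1, 2, 4, 5])              # rows kept in the subset
_, new_group = np.unique(row_group[used_indices], return_counts=True)
print(new_group)                                   # [1 2 1]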
@@ -954,8 +966,9 @@ class Dataset(object):
             else:
                 # create train
                 self._lazy_init(self.data, label=self.label,
-                                weight=self.weight, group=self.group, init_score=self.init_score,
-                                predictor=self._predictor, silent=self.silent, feature_name=self.feature_name,
+                                weight=self.weight, group=self.group,
+                                init_score=self.init_score, predictor=self._predictor,
+                                silent=self.silent, feature_name=self.feature_name,
                                 categorical_feature=self.categorical_feature, params=self.params)
         if self.free_raw_data:
             self.data = None
@@ -1158,11 +1171,13 @@ class Dataset(object):
                 warnings.warn('Using categorical_feature in Dataset.')
                 return self
             else:
-                warnings.warn('categorical_feature in Dataset is overridden. New categorical_feature is {}'.format(sorted(list(categorical_feature))))
+                warnings.warn('categorical_feature in Dataset is overridden.\n'
+                              'New categorical_feature is {}'.format(sorted(list(categorical_feature))))
                 self.categorical_feature = categorical_feature
                 return self._free_handle()
         else:
-            raise LightGBMError("Cannot set categorical feature after freed raw data, set free_raw_data=False when construct Dataset to avoid this.")
+            raise LightGBMError("Cannot set categorical feature after freed raw data, "
+                                "set free_raw_data=False when construct Dataset to avoid this.")
 
     def _set_predictor(self, predictor):
         """
@@ -1175,7 +1190,8 @@ class Dataset(object):
             self._predictor = predictor
             return self._free_handle()
         else:
-            raise LightGBMError("Cannot set predictor after freed raw data, set free_raw_data=False when construct Dataset to avoid this.")
+            raise LightGBMError("Cannot set predictor after freed raw data, "
+                                "set free_raw_data=False when construct Dataset to avoid this.")
 
     def set_reference(self, reference):
         """Set reference Dataset.
@@ -1190,7 +1206,9 @@ class Dataset(object):
         self : Dataset
             Dataset with set reference.
         """
-        self.set_categorical_feature(reference.categorical_feature).set_feature_name(reference.feature_name)._set_predictor(reference._predictor)
+        self.set_categorical_feature(reference.categorical_feature) \
+            .set_feature_name(reference.feature_name) \
+            ._set_predictor(reference._predictor)
         # we're done if self and reference share a common upstrem reference
         if self.get_ref_chain().intersection(reference.get_ref_chain()):
             return self
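
The chained call works because each setter returns self, so configuration reads as one fluent expression, and the trailing backslashes continue it across lines. A toy version of the pattern (hypothetical class, not from the patch):

class Config(object):
    def set_a(self, a):
        self.a = a
        return self            # returning self enables chaining

    def set_b(self, b):
        self.b = b
        return self

cfg = Config().set_a(1) \
              .set_b(2)
print(cfg.a, cfg.b)            # 1 2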
@@ -1198,7 +1216,8 @@ class Dataset(object):
             self.reference = reference
             return self._free_handle()
         else:
-            raise LightGBMError("Cannot set reference after freed raw data, set free_raw_data=False when construct Dataset to avoid this.")
+            raise LightGBMError("Cannot set reference after freed raw data, "
+                                "set free_raw_data=False when construct Dataset to avoid this.")
 
     def set_feature_name(self, feature_name):
         """Set feature name.
@@ -1217,7 +1236,8 @@ class Dataset(object):
         self.feature_name = feature_name
         if self.handle is not None and feature_name is not None and feature_name != 'auto':
             if len(feature_name) != self.num_feature():
-                raise ValueError("Length of feature_name({}) and num_feature({}) don't match".format(len(feature_name), self.num_feature()))
+                raise ValueError("Length of feature_name({}) and num_feature({}) don't match"
+                                 .format(len(feature_name), self.num_feature()))
             c_feature_name = [c_str(name) for name in feature_name]
             _safe_call(_LIB.LGBM_DatasetSetFeatureNames(
                 self.handle,
@@ -1445,7 +1465,8 @@ class Booster(object):
         if train_set is not None:
             # Training task
             if not isinstance(train_set, Dataset):
-                raise TypeError('Training data should be Dataset instance, met {}'.format(type(train_set).__name__))
+                raise TypeError('Training data should be Dataset instance, met {}'
+                                .format(type(train_set).__name__))
             params_str = param_dict_to_str(params)
             # construct booster object
             self.handle = ctypes.c_void_p()
@@ -1640,9 +1661,11 @@ class Booster(object):
             Booster with set validation data.
         """
         if not isinstance(data, Dataset):
-            raise TypeError('Validation data should be Dataset instance, met {}'.format(type(data).__name__))
+            raise TypeError('Validation data should be Dataset instance, met {}'
+                            .format(type(data).__name__))
         if data._predictor is not self.__init_predictor:
-            raise LightGBMError("Add validation data failed, you should use same predictor for these data")
+            raise LightGBMError("Add validation data failed, "
+                                "you should use same predictor for these data")
         _safe_call(_LIB.LGBM_BoosterAddValidData(
             self.handle,
             data.construct().handle))
@@ -1700,9 +1723,11 @@ class Booster(object):
         # need reset training data
         if train_set is not None and train_set is not self.train_set:
             if not isinstance(train_set, Dataset):
-                raise TypeError('Training data should be Dataset instance, met {}'.format(type(train_set).__name__))
+                raise TypeError('Training data should be Dataset instance, met {}'
+                                .format(type(train_set).__name__))
             if train_set._predictor is not self.__init_predictor:
-                raise LightGBMError("Replace training data failed, you should use same predictor for these data")
+                raise LightGBMError("Replace training data failed, "
+                                    "you should use same predictor for these data")
             self.train_set = train_set
             _safe_call(_LIB.LGBM_BoosterResetTrainingData(
                 self.handle,
@@ -1748,7 +1773,8 @@ class Booster(object):
         assert grad.flags.c_contiguous
         assert hess.flags.c_contiguous
         if len(grad) != len(hess):
-            raise ValueError("Lengths of gradient({}) and hessian({}) don't match".format(len(grad), len(hess)))
+            raise ValueError("Lengths of gradient({}) and hessian({}) don't match"
+                             .format(len(grad), len(hess)))
         is_finished = ctypes.c_int(0)
         _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
             self.handle,
@@ -2051,7 +2077,8 @@ class Booster(object):
             ptr_string_buffer))
         return json.loads(string_buffer.value.decode())
 
-    def predict(self, data, num_iteration=None, raw_score=False, pred_leaf=False, pred_contrib=False,
+    def predict(self, data, num_iteration=None,
+                raw_score=False, pred_leaf=False, pred_contrib=False,
                 data_has_header=False, is_reshape=True, pred_parameter=None, **kwargs):
         """Make a prediction.
@@ -2064,7 +2091,6 @@ class Booster(object):
             Limit number of iterations in the prediction.
             If None, if the best iteration exists, it is used; otherwise, all iterations are used.
             If <= 0, all iterations are used (no limits).
-
         raw_score : bool, optional (default=False)
             Whether to predict raw scores.
         pred_leaf : bool, optional (default=False)
@@ -2093,7 +2119,9 @@ class Booster(object):
         predictor = self._to_predictor(kwargs)
         if num_iteration is None:
             num_iteration = self.best_iteration
-        return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)
+        return predictor.predict(data, num_iteration,
+                                 raw_score, pred_leaf, pred_contrib,
+                                 data_has_header, is_reshape)
 
     def refit(self, data, label, decay_rate=0.9, **kwargs):
         """Refit the existing Booster by new data.
@@ -2106,7 +2134,8 @@ class Booster(object):
         label : list, numpy 1-D array or pandas one-column DataFrame/Series
             Label for refit.
         decay_rate : float, optional (default=0.9)
-            Decay rate of refit, will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
+            Decay rate of refit,
+            will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
         **kwargs : other parameters for refit
             These parameters will be passed to ``predict`` method.
@@ -2248,7 +2277,8 @@ class Booster(object):
         if tmp_out_len.value != self.__num_inner_eval:
             raise ValueError("Wrong length of eval results")
         for i in range_(self.__num_inner_eval):
-            ret.append((data_name, self.__name_inner_eval[i], result[i], self.__higher_better_inner_eval[i]))
+            ret.append((data_name, self.__name_inner_eval[i],
+                        result[i], self.__higher_better_inner_eval[i]))
         if feval is not None:
             if data_idx == 0:
                 cur_data = self.train_set
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -133,7 +133,8 @@ def reset_parameter(**kwargs):
                 raise RuntimeError("cannot reset {} during training".format(repr(key)))
             if isinstance(value, list):
                 if len(value) != env.end_iteration - env.begin_iteration:
-                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(key)))
+                    raise ValueError("Length of list {} has to equal to 'num_boost_round'."
+                                     .format(repr(key)))
                 new_param = value[env.iteration - env.begin_iteration]
             else:
                 new_param = value(env.iteration - env.begin_iteration)
@@ -180,7 +181,8 @@ def early_stopping(stopping_rounds, verbose=True):
     def init(env):
         """internal function"""
         if not env.evaluation_result_list:
-            raise ValueError('For early stopping, at least one dataset and eval metric is required for evaluation')
+            raise ValueError('For early stopping, '
+                             'at least one dataset and eval metric is required for evaluation')
         if verbose:
             msg = "Training until validation scores don't improve for {} rounds."
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -129,7 +129,10 @@ def train(params, train_set, num_boost_round=100,
     if not isinstance(train_set, Dataset):
         raise TypeError("Training only accepts Dataset object")
 
-    train_set._update_params(params)._set_predictor(predictor).set_feature_name(feature_name).set_categorical_feature(categorical_feature)
+    train_set._update_params(params) \
+             ._set_predictor(predictor) \
+             .set_feature_name(feature_name) \
+             .set_categorical_feature(categorical_feature)
 
     is_valid_contain_train = False
     train_data_name = "training"
@@ -341,7 +344,7 @@ def cv(params, train_set, num_boost_round=100,
         Data to be trained on.
     num_boost_round : int, optional (default=100)
         Number of boosting iterations.
-    folds : a generator or iterator of (train_idx, test_idx) tuples, scikit-learn splitter object or None, optional (default=None)
+    folds : generator or iterator of (train_idx, test_idx) tuples, scikit-learn splitter object or None, optional (default=None)
         If generator or iterator, it should yield the train and test indices for the each fold.
         If object, it should be one of the scikit-learn splitter classes
         (http://scikit-learn.org/stable/modules/classes.html#splitter-classes)
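
Per the reworded docstring, folds accepts a scikit-learn splitter object directly. A hedged usage sketch (data and parameter values are illustrative, not from the patch):

import lightgbm as lgb
import numpy as np
from sklearn.model_selection import KFold

X = np.random.rand(100, 5)
y = np.random.rand(100)
dtrain = lgb.Dataset(X, label=y)
cv_results = lgb.cv({'objective': 'regression', 'verbose': -1},
                    dtrain,
                    num_boost_round=20,
                    folds=KFold(n_splits=5, shuffle=True, random_state=0))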
@@ -434,7 +437,10 @@ def cv(params, train_set, num_boost_round=100,
         predictor = init_model._to_predictor(dict(init_model.params, **params))
     else:
         predictor = None
-    train_set._update_params(params)._set_predictor(predictor).set_feature_name(feature_name).set_categorical_feature(categorical_feature)
+    train_set._update_params(params) \
+             ._set_predictor(predictor) \
+             .set_feature_name(feature_name) \
+             .set_categorical_feature(categorical_feature)
 
     if metrics is not None:
         params['metric'] = metrics
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -469,7 +469,8 @@ class LGBMModel(_LGBMModelBase):
             elif isinstance(collection, dict):
                 return collection.get(i, None)
             else:
-                raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
+                raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group '
+                                'should be dict or list')
 
         if isinstance(eval_set, tuple):
             eval_set = [eval_set]
@@ -480,14 +481,16 @@ class LGBMModel(_LGBMModelBase):
                 else:
                     valid_weight = _get_meta_data(eval_sample_weight, i)
                     if _get_meta_data(eval_class_weight, i) is not None:
-                        valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i), valid_data[1])
+                        valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i),
+                                                                             valid_data[1])
                         if valid_weight is None or len(valid_weight) == 0:
                             valid_weight = valid_class_sample_weight
                         else:
                             valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
                     valid_init_score = _get_meta_data(eval_init_score, i)
                     valid_group = _get_meta_data(eval_group, i)
-                    valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
+                    valid_set = _construct_dataset(valid_data[0], valid_data[1],
+                                                   valid_weight, valid_init_score, valid_group, params)
                 valid_sets.append(valid_set)
 
         self._Booster = train(params, train_set,
@@ -786,8 +789,10 @@ class LGBMRanker(LGBMModel):
                 raise ValueError("Eval_group cannot be None when eval_set is not None")
             elif len(eval_group) != len(eval_set):
                 raise ValueError("Length of eval_group should be equal to eval_set")
-            elif (isinstance(eval_group, dict) and any(i not in eval_group or eval_group[i] is None for i in range_(len(eval_group)))) \
-                    or (isinstance(eval_group, list) and any(group is None for group in eval_group)):
+            elif (isinstance(eval_group, dict)
+                  and any(i not in eval_group or eval_group[i] is None for i in range_(len(eval_group)))
+                  or isinstance(eval_group, list)
+                  and any(group is None for group in eval_group)):
                 raise ValueError("Should set group for all eval datasets for ranking task; "
                                  "if you use dict, the index should start from 0")