Commit 7509ec8a authored by Guolin Ke's avatar Guolin Ke Committed by Nikita Titov
Browse files

[python] fix group type in lgb.cv (#2384)

parent 29525ffe
......@@ -307,7 +307,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if hasattr(folds, 'split'):
group_info = full_data.get_group()
if group_info is not None:
group_info = group_info.astype(int)
group_info = np.array(group_info, dtype=int)
flatted_group = np.repeat(range_(len(group_info)), repeats=group_info)
else:
flatted_group = np.zeros(num_data, dtype=int)
......@@ -317,7 +317,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for lambdarank cv.')
# lambdarank task, split according to groups
group_info = full_data.get_group().astype(int)
group_info = np.array(full_data.get_group(), dtype=int)
flatted_group = np.repeat(range_(len(group_info)), repeats=group_info)
group_kfold = _LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment