Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
7509ec8a
Commit
7509ec8a
authored
Sep 06, 2019
by
Guolin Ke
Committed by
Nikita Titov
Sep 06, 2019
Browse files
[python] fix group type in lgb.cv (#2384)
parent
29525ffe
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+2
-2
No files found.
python-package/lightgbm/engine.py
View file @
7509ec8a
...
@@ -307,7 +307,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
...
@@ -307,7 +307,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if
hasattr
(
folds
,
'split'
):
if
hasattr
(
folds
,
'split'
):
group_info
=
full_data
.
get_group
()
group_info
=
full_data
.
get_group
()
if
group_info
is
not
None
:
if
group_info
is
not
None
:
group_info
=
group_info
.
as
type
(
int
)
group_info
=
np
.
array
(
group_info
,
d
type
=
int
)
flatted_group
=
np
.
repeat
(
range_
(
len
(
group_info
)),
repeats
=
group_info
)
flatted_group
=
np
.
repeat
(
range_
(
len
(
group_info
)),
repeats
=
group_info
)
else
:
else
:
flatted_group
=
np
.
zeros
(
num_data
,
dtype
=
int
)
flatted_group
=
np
.
zeros
(
num_data
,
dtype
=
int
)
...
@@ -317,7 +317,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
...
@@ -317,7 +317,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if
not
SKLEARN_INSTALLED
:
if
not
SKLEARN_INSTALLED
:
raise
LightGBMError
(
'Scikit-learn is required for lambdarank cv.'
)
raise
LightGBMError
(
'Scikit-learn is required for lambdarank cv.'
)
# lambdarank task, split according to groups
# lambdarank task, split according to groups
group_info
=
full_data
.
get_group
()
.
as
type
(
int
)
group_info
=
np
.
array
(
full_data
.
get_group
()
,
d
type
=
int
)
flatted_group
=
np
.
repeat
(
range_
(
len
(
group_info
)),
repeats
=
group_info
)
flatted_group
=
np
.
repeat
(
range_
(
len
(
group_info
)),
repeats
=
group_info
)
group_kfold
=
_LGBMGroupKFold
(
n_splits
=
nfold
)
group_kfold
=
_LGBMGroupKFold
(
n_splits
=
nfold
)
folds
=
group_kfold
.
split
(
X
=
np
.
zeros
(
num_data
),
groups
=
flatted_group
)
folds
=
group_kfold
.
split
(
X
=
np
.
zeros
(
num_data
),
groups
=
flatted_group
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment