Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
48e3629d
Unverified
Commit
48e3629d
authored
Jan 03, 2024
by
James Lamb
Committed by
GitHub
Jan 03, 2024
Browse files
[python-package] fix mypy error about pandas categorical features (#6253)
parent
2bd60c8f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+5
-5
No files found.
python-package/lightgbm/basic.py
View file @
48e3629d
...
@@ -786,7 +786,7 @@ def _data_from_pandas(
...
@@ -786,7 +786,7 @@ def _data_from_pandas(
feature_name
:
_LGBM_FeatureNameConfiguration
,
feature_name
:
_LGBM_FeatureNameConfiguration
,
categorical_feature
:
_LGBM_CategoricalFeatureConfiguration
,
categorical_feature
:
_LGBM_CategoricalFeatureConfiguration
,
pandas_categorical
:
Optional
[
List
[
List
]]
pandas_categorical
:
Optional
[
List
[
List
]]
)
->
Tuple
[
np
.
ndarray
,
List
[
str
],
List
[
str
],
List
[
List
]]:
)
->
Tuple
[
np
.
ndarray
,
List
[
str
],
Union
[
List
[
str
],
List
[
int
]],
List
[
List
]]:
if
len
(
data
.
shape
)
!=
2
or
data
.
shape
[
0
]
<
1
:
if
len
(
data
.
shape
)
!=
2
or
data
.
shape
[
0
]
<
1
:
raise
ValueError
(
'Input data must be 2 dimensional and non empty.'
)
raise
ValueError
(
'Input data must be 2 dimensional and non empty.'
)
...
@@ -800,7 +800,7 @@ def _data_from_pandas(
...
@@ -800,7 +800,7 @@ def _data_from_pandas(
# determine categorical features
# determine categorical features
cat_cols
=
[
col
for
col
,
dtype
in
zip
(
data
.
columns
,
data
.
dtypes
)
if
isinstance
(
dtype
,
pd_CategoricalDtype
)]
cat_cols
=
[
col
for
col
,
dtype
in
zip
(
data
.
columns
,
data
.
dtypes
)
if
isinstance
(
dtype
,
pd_CategoricalDtype
)]
cat_cols_not_ordered
=
[
col
for
col
in
cat_cols
if
not
data
[
col
].
cat
.
ordered
]
cat_cols_not_ordered
:
List
[
str
]
=
[
col
for
col
in
cat_cols
if
not
data
[
col
].
cat
.
ordered
]
if
pandas_categorical
is
None
:
# train dataset
if
pandas_categorical
is
None
:
# train dataset
pandas_categorical
=
[
list
(
data
[
col
].
cat
.
categories
)
for
col
in
cat_cols
]
pandas_categorical
=
[
list
(
data
[
col
].
cat
.
categories
)
for
col
in
cat_cols
]
else
:
else
:
...
@@ -811,10 +811,10 @@ def _data_from_pandas(
...
@@ -811,10 +811,10 @@ def _data_from_pandas(
data
[
col
]
=
data
[
col
].
cat
.
set_categories
(
category
)
data
[
col
]
=
data
[
col
].
cat
.
set_categories
(
category
)
if
len
(
cat_cols
):
# cat_cols is list
if
len
(
cat_cols
):
# cat_cols is list
data
[
cat_cols
]
=
data
[
cat_cols
].
apply
(
lambda
x
:
x
.
cat
.
codes
).
replace
({
-
1
:
np
.
nan
})
data
[
cat_cols
]
=
data
[
cat_cols
].
apply
(
lambda
x
:
x
.
cat
.
codes
).
replace
({
-
1
:
np
.
nan
})
if
categorical_feature
==
'auto'
:
# use cat cols from DataFrame
# use cat cols from DataFrame
if
categorical_feature
==
'auto'
:
categorical_feature
=
cat_cols_not_ordered
categorical_feature
=
cat_cols_not_ordered
else
:
# use cat cols specified by user
categorical_feature
=
list
(
categorical_feature
)
# type: ignore[assignment]
df_dtypes
=
[
dtype
.
type
for
dtype
in
data
.
dtypes
]
df_dtypes
=
[
dtype
.
type
for
dtype
in
data
.
dtypes
]
# so that the target dtype considers floats
# so that the target dtype considers floats
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment