Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
b27d81ea
Unverified
Commit
b27d81ea
authored
Mar 04, 2024
by
James Lamb
Committed by
GitHub
Mar 04, 2024
Browse files
[ci] [python-package] check for untyped definitions with mypy (#6339)
parent
1a292f89
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
38 additions
and
33 deletions
+38
-33
.ci/test.sh
.ci/test.sh
+1
-0
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+11
-11
python-package/lightgbm/compat.py
python-package/lightgbm/compat.py
+16
-16
python-package/lightgbm/plotting.py
python-package/lightgbm/plotting.py
+8
-5
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+1
-1
python-package/pyproject.toml
python-package/pyproject.toml
+1
-0
No files found.
.ci/test.sh
View file @
b27d81ea
...
...
@@ -74,6 +74,7 @@ if [[ $TASK == "lint" ]]; then
${
CONDA_PYTHON_REQUIREMENT
}
\
cmakelint
\
cpplint
\
'matplotlib>=3.8.3'
\
mypy
\
'pre-commit>=3.6.0'
\
'pyarrow>=14.0'
\
...
...
python-package/lightgbm/basic.py
View file @
b27d81ea
...
...
@@ -13,7 +13,7 @@ from os import SEEK_END, environ
from
os.path
import
getsize
from
pathlib
import
Path
from
tempfile
import
NamedTemporaryFile
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Iterable
,
Iterator
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
scipy.sparse
...
...
@@ -537,13 +537,13 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
class
_TempFile
:
"""Proxy class to workaround errors on Windows."""
def
__enter__
(
self
):
def
__enter__
(
self
)
->
"_TempFile"
:
with
NamedTemporaryFile
(
prefix
=
"lightgbm_tmp_"
,
delete
=
True
)
as
f
:
self
.
name
=
f
.
name
self
.
path
=
Path
(
self
.
name
)
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
)
:
def
__exit__
(
self
,
exc_type
:
Any
,
exc_val
:
Any
,
exc_tb
:
Any
)
->
None
:
if
self
.
path
.
is_file
():
self
.
path
.
unlink
()
...
...
@@ -595,7 +595,7 @@ class _ConfigAliases:
)
@
classmethod
def
get
(
cls
,
*
args
)
->
Set
[
str
]:
def
get
(
cls
,
*
args
:
str
)
->
Set
[
str
]:
if
cls
.
aliases
is
None
:
cls
.
aliases
=
cls
.
_get_all_param_aliases
()
ret
=
set
()
...
...
@@ -610,7 +610,7 @@ class _ConfigAliases:
return
cls
.
aliases
.
get
(
name
,
[
name
])
@
classmethod
def
get_by_alias
(
cls
,
*
args
)
->
Set
[
str
]:
def
get_by_alias
(
cls
,
*
args
:
str
)
->
Set
[
str
]:
if
cls
.
aliases
is
None
:
cls
.
aliases
=
cls
.
_get_all_param_aliases
()
ret
=
set
(
args
)
...
...
@@ -1563,7 +1563,7 @@ class _InnerPredictor:
start_iteration
:
int
,
num_iteration
:
int
,
predict_type
:
int
,
):
)
->
Tuple
[
Union
[
List
[
scipy
.
sparse
.
csc_matrix
],
List
[
scipy
.
sparse
.
csr_matrix
]],
int
]
:
ptr_indptr
,
type_ptr_indptr
,
__
=
_c_int_array
(
csc
.
indptr
)
ptr_data
,
type_ptr_data
,
_
=
_c_float_array
(
csc
.
data
)
csc_indices
=
csc
.
indices
.
astype
(
np
.
int32
,
copy
=
False
)
...
...
@@ -1813,7 +1813,7 @@ class Dataset:
self
.
_need_slice
=
True
self
.
_predictor
:
Optional
[
_InnerPredictor
]
=
None
self
.
pandas_categorical
:
Optional
[
List
[
List
]]
=
None
self
.
_params_back_up
=
None
self
.
_params_back_up
:
Optional
[
Dict
[
str
,
Any
]]
=
None
self
.
version
=
0
self
.
_start_row
=
0
# Used when pushing rows one by one.
...
...
@@ -2195,7 +2195,7 @@ class Dataset:
return
self
.
set_feature_name
(
feature_name
)
@
staticmethod
def
_yield_row_from_seqlist
(
seqs
:
List
[
Sequence
],
indices
:
Iterable
[
int
]):
def
_yield_row_from_seqlist
(
seqs
:
List
[
Sequence
],
indices
:
Iterable
[
int
])
->
Iterator
[
np
.
ndarray
]
:
offset
=
0
seq_id
=
0
seq
=
seqs
[
seq_id
]
...
...
@@ -2697,7 +2697,7 @@ class Dataset:
return
self
params
=
deepcopy
(
params
)
def
update
():
def
update
()
->
None
:
if
not
self
.
params
:
self
.
params
=
params
else
:
...
...
@@ -3704,7 +3704,7 @@ class Booster:
def
__copy__
(
self
)
->
"Booster"
:
return
self
.
__deepcopy__
(
None
)
def
__deepcopy__
(
self
,
_
)
->
"Booster"
:
def
__deepcopy__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
"Booster"
:
model_str
=
self
.
model_to_string
(
num_iteration
=-
1
)
return
Booster
(
model_str
=
model_str
)
...
...
@@ -4757,7 +4757,7 @@ class Booster:
dataset_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
free_raw_data
:
bool
=
True
,
validate_features
:
bool
=
False
,
**
kwargs
,
**
kwargs
:
Any
,
)
->
"Booster"
:
"""Refit the existing Booster by new data.
...
...
python-package/lightgbm/compat.py
View file @
b27d81ea
# coding: utf-8
"""Compatibility library."""
from
typing
import
List
from
typing
import
Any
,
List
"""pandas"""
try
:
...
...
@@ -20,19 +20,19 @@ except ImportError:
class
pd_Series
:
# type: ignore
"""Dummy class for pandas.Series."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
pd_DataFrame
:
# type: ignore
"""Dummy class for pandas.DataFrame."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
pd_CategoricalDtype
:
# type: ignore
"""Dummy class for pandas.CategoricalDtype."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
concat
=
None
...
...
@@ -45,7 +45,7 @@ except ImportError:
class
np_random_Generator
:
# type: ignore
"""Dummy class for np.random.Generator."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
...
...
@@ -80,7 +80,7 @@ except ImportError:
class
dt_DataTable
:
# type: ignore
"""Dummy class for datatable.DataTable."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
...
...
@@ -104,7 +104,7 @@ try:
from
sklearn.utils.validation
import
check_consistent_length
# dummy function to support older version of scikit-learn
def
_check_sample_weight
(
sample_weight
,
X
,
dtype
=
None
)
:
def
_check_sample_weight
(
sample_weight
:
Any
,
X
:
Any
,
dtype
:
Any
=
None
)
->
Any
:
check_consistent_length
(
sample_weight
,
X
)
return
sample_weight
...
...
@@ -176,31 +176,31 @@ except ImportError:
class
Client
:
# type: ignore
"""Dummy class for dask.distributed.Client."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
Future
:
# type: ignore
"""Dummy class for dask.distributed.Future."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
dask_Array
:
# type: ignore
"""Dummy class for dask.array.Array."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
dask_DataFrame
:
# type: ignore
"""Dummy class for dask.dataframe.DataFrame."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
dask_Series
:
# type: ignore
"""Dummy class for dask.dataframe.Series."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
...
...
@@ -222,19 +222,19 @@ except ImportError:
class
pa_Array
:
# type: ignore
"""Dummy class for pa.Array."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
pa_ChunkedArray
:
# type: ignore
"""Dummy class for pa.ChunkedArray."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
pa_Table
:
# type: ignore
"""Dummy class for pa.Table."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
arrow_cffi
:
# type: ignore
...
...
@@ -245,7 +245,7 @@ except ImportError:
cast
=
None
new
=
None
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
pass
class
pa_compute
:
# type: ignore
...
...
python-package/lightgbm/plotting.py
View file @
b27d81ea
...
...
@@ -3,7 +3,7 @@
import
math
from
copy
import
deepcopy
from
io
import
BytesIO
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -19,6 +19,9 @@ __all__ = [
"plot_tree"
,
]
if
TYPE_CHECKING
:
import
matplotlib
def
_check_not_tuple_of_2_elements
(
obj
:
Any
,
obj_name
:
str
)
->
None
:
"""Check object is not tuple or does not have 2 elements."""
...
...
@@ -32,7 +35,7 @@ def _float2str(value: float, precision: Optional[int]) -> str:
def
plot_importance
(
booster
:
Union
[
Booster
,
LGBMModel
],
ax
=
None
,
ax
:
"Optional[matplotlib.axes.Axes]"
=
None
,
height
:
float
=
0.2
,
xlim
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
ylim
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
...
...
@@ -168,7 +171,7 @@ def plot_split_value_histogram(
booster
:
Union
[
Booster
,
LGBMModel
],
feature
:
Union
[
int
,
str
],
bins
:
Union
[
int
,
str
,
None
]
=
None
,
ax
=
None
,
ax
:
"Optional[matplotlib.axes.Axes]"
=
None
,
width_coef
:
float
=
0.8
,
xlim
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
ylim
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
...
...
@@ -284,7 +287,7 @@ def plot_metric(
booster
:
Union
[
Dict
,
LGBMModel
],
metric
:
Optional
[
str
]
=
None
,
dataset_names
:
Optional
[
List
[
str
]]
=
None
,
ax
=
None
,
ax
:
"Optional[matplotlib.axes.Axes]"
=
None
,
xlim
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
ylim
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
title
:
Optional
[
str
]
=
"Metric during training"
,
...
...
@@ -735,7 +738,7 @@ def create_tree_digraph(
def
plot_tree
(
booster
:
Union
[
Booster
,
LGBMModel
],
ax
=
None
,
ax
:
"Optional[matplotlib.axes.Axes]"
=
None
,
tree_index
:
int
=
0
,
figsize
:
Optional
[
Tuple
[
float
,
float
]]
=
None
,
dpi
:
Optional
[
int
]
=
None
,
...
...
python-package/lightgbm/sklearn.py
View file @
b27d81ea
...
...
@@ -478,7 +478,7 @@ class LGBMModel(_LGBMModelBase):
random_state
:
Optional
[
Union
[
int
,
np
.
random
.
RandomState
,
"np.random.Generator"
]]
=
None
,
n_jobs
:
Optional
[
int
]
=
None
,
importance_type
:
str
=
"split"
,
**
kwargs
,
**
kwargs
:
Any
,
):
r
"""Construct a gradient boosting model.
...
...
python-package/pyproject.toml
View file @
b27d81ea
...
...
@@ -92,6 +92,7 @@ skip_glob = [
]
[tool.mypy]
disallow_untyped_defs
=
true
exclude
=
'build/*|compile/*|docs/*|examples/*|external_libs/*|lightgbm-python/*|tests/*'
ignore_missing_imports
=
true
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment