Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
6de94ef8
Unverified
Commit
6de94ef8
authored
Mar 31, 2023
by
James Lamb
Committed by
GitHub
Mar 31, 2023
Browse files
[python-package] add type annotations on some array methods in basic.py (#5813)
parent
216eaff7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
4 deletions
+15
-4
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+15
-4
No files found.
python-package/lightgbm/basic.py
View file @
6de94ef8
...
...
@@ -34,6 +34,10 @@ _ctypes_int_ptr = Union[
"ctypes._Pointer[ctypes.c_int32]"
,
"ctypes._Pointer[ctypes.c_int64]"
]
_ctypes_int_array
=
Union
[
"ctypes.Array[ctypes._Pointer[ctypes.c_int32]]"
,
"ctypes.Array[ctypes._Pointer[ctypes.c_int64]]"
]
_ctypes_float_ptr
=
Union
[
"ctypes._Pointer[ctypes.c_float]"
,
"ctypes._Pointer[ctypes.c_double]"
...
...
@@ -589,13 +593,16 @@ def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray:
return
data
def
_c_float_array
(
data
):
def
_c_float_array
(
data
:
np
.
ndarray
)
->
Tuple
[
_ctypes_float_ptr
,
int
,
np
.
ndarray
]:
"""Get pointer of float numpy array / list."""
if
_is_1d_list
(
data
):
data
=
np
.
array
(
data
,
copy
=
False
)
if
_is_numpy_1d_array
(
data
):
data
=
_convert_from_sliced_object
(
data
)
assert
data
.
flags
.
c_contiguous
ptr_data
:
_ctypes_float_ptr
if
data
.
dtype
==
np
.
float32
:
ptr_data
=
data
.
ctypes
.
data_as
(
ctypes
.
POINTER
(
ctypes
.
c_float
))
type_data
=
_C_API_DTYPE_FLOAT32
...
...
@@ -609,13 +616,16 @@ def _c_float_array(data):
return
(
ptr_data
,
type_data
,
data
)
# return `data` to avoid the temporary copy is freed
def
_c_int_array
(
data
):
def
_c_int_array
(
data
:
np
.
ndarray
)
->
Tuple
[
_ctypes_int_ptr
,
int
,
np
.
ndarray
]:
"""Get pointer of int numpy array / list."""
if
_is_1d_list
(
data
):
data
=
np
.
array
(
data
,
copy
=
False
)
if
_is_numpy_1d_array
(
data
):
data
=
_convert_from_sliced_object
(
data
)
assert
data
.
flags
.
c_contiguous
ptr_data
:
_ctypes_int_ptr
if
data
.
dtype
==
np
.
int32
:
ptr_data
=
data
.
ctypes
.
data_as
(
ctypes
.
POINTER
(
ctypes
.
c_int32
))
type_data
=
_C_API_DTYPE_INT32
...
...
@@ -1624,10 +1634,10 @@ class Dataset:
# c type: double**
# each double* element points to start of each column of sample data.
sample_col_ptr
=
(
ctypes
.
POINTER
(
ctypes
.
c_double
)
*
ncol
)()
sample_col_ptr
:
_ctypes_float_array
=
(
ctypes
.
POINTER
(
ctypes
.
c_double
)
*
ncol
)()
# c type int**
# each int* points to start of indices for each column
indices_col_ptr
=
(
ctypes
.
POINTER
(
ctypes
.
c_int32
)
*
ncol
)()
indices_col_ptr
:
_ctypes_int_array
=
(
ctypes
.
POINTER
(
ctypes
.
c_int32
)
*
ncol
)()
for
i
in
range
(
ncol
):
sample_col_ptr
[
i
]
=
_c_float_array
(
sample_data
[
i
])[
0
]
indices_col_ptr
[
i
]
=
_c_int_array
(
sample_indices
[
i
])[
0
]
...
...
@@ -2374,6 +2384,7 @@ class Dataset:
dtype
=
np
.
int32
if
field_name
==
'group'
else
np
.
float32
data
=
_list_to_1d_numpy
(
data
,
dtype
,
name
=
field_name
)
ptr_data
:
Union
[
_ctypes_float_ptr
,
_ctypes_int_ptr
]
if
data
.
dtype
==
np
.
float32
or
data
.
dtype
==
np
.
float64
:
ptr_data
,
type_data
,
_
=
_c_float_array
(
data
)
elif
data
.
dtype
==
np
.
int32
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment