Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
01774bb9
Unverified
Commit
01774bb9
authored
Aug 23, 2022
by
James Lamb
Committed by
GitHub
Aug 23, 2022
Browse files
[python-package] add more type hints on Dataset (#5431)
parent
78f95e41
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
14 deletions
+32
-14
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+32
-14
No files found.
python-package/lightgbm/basic.py
View file @
01774bb9
...
@@ -619,7 +619,10 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
...
@@ -619,7 +619,10 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
return
pandas_str
return
pandas_str
def
_load_pandas_categorical
(
file_name
=
None
,
model_str
=
None
):
def
_load_pandas_categorical
(
file_name
:
Optional
[
Union
[
str
,
Path
]]
=
None
,
model_str
:
Optional
[
str
]
=
None
)
->
Optional
[
str
]:
pandas_key
=
'pandas_categorical:'
pandas_key
=
'pandas_categorical:'
offset
=
-
len
(
pandas_key
)
offset
=
-
len
(
pandas_key
)
if
file_name
is
not
None
:
if
file_name
is
not
None
:
...
@@ -1879,7 +1882,15 @@ class Dataset:
...
@@ -1879,7 +1882,15 @@ class Dataset:
self
.
feature_name
=
self
.
get_feature_name
()
self
.
feature_name
=
self
.
get_feature_name
()
return
self
return
self
def
create_valid
(
self
,
data
,
label
=
None
,
weight
=
None
,
group
=
None
,
init_score
=
None
,
params
=
None
):
def
create_valid
(
self
,
data
,
label
=
None
,
weight
=
None
,
group
=
None
,
init_score
=
None
,
params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
"Dataset"
:
"""Create validation data align with current Dataset.
"""Create validation data align with current Dataset.
Parameters
Parameters
...
@@ -1966,7 +1977,7 @@ class Dataset:
...
@@ -1966,7 +1977,7 @@ class Dataset:
c_str
(
str
(
filename
))))
c_str
(
str
(
filename
))))
return
self
return
self
def
_update_params
(
self
,
params
)
:
def
_update_params
(
self
,
params
:
Optional
[
Dict
[
str
,
Any
]])
->
"Dataset"
:
if
not
params
:
if
not
params
:
return
self
return
self
params
=
deepcopy
(
params
)
params
=
deepcopy
(
params
)
...
@@ -1999,7 +2010,11 @@ class Dataset:
...
@@ -1999,7 +2010,11 @@ class Dataset:
self
.
params_back_up
=
None
self
.
params_back_up
=
None
return
self
return
self
def
set_field
(
self
,
field_name
,
data
):
def
set_field
(
self
,
field_name
:
str
,
data
)
->
"Dataset"
:
"""Set property into the Dataset.
"""Set property into the Dataset.
Parameters
Parameters
...
@@ -2135,7 +2150,10 @@ class Dataset:
...
@@ -2135,7 +2150,10 @@ class Dataset:
raise
LightGBMError
(
"Cannot set categorical feature after freed raw data, "
raise
LightGBMError
(
"Cannot set categorical feature after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this."
)
"set free_raw_data=False when construct Dataset to avoid this."
)
def
_set_predictor
(
self
,
predictor
):
def
_set_predictor
(
self
,
predictor
:
Optional
[
_InnerPredictor
]
)
->
"Dataset"
:
"""Set predictor for continued training.
"""Set predictor for continued training.
It is not recommended for user to call this function.
It is not recommended for user to call this function.
...
@@ -2156,7 +2174,7 @@ class Dataset:
...
@@ -2156,7 +2174,7 @@ class Dataset:
"set free_raw_data=False when construct Dataset to avoid this."
)
"set free_raw_data=False when construct Dataset to avoid this."
)
return
self
return
self
def
set_reference
(
self
,
reference
)
:
def
set_reference
(
self
,
reference
:
"Dataset"
)
->
"Dataset"
:
"""Set reference Dataset.
"""Set reference Dataset.
Parameters
Parameters
...
@@ -2207,7 +2225,7 @@ class Dataset:
...
@@ -2207,7 +2225,7 @@ class Dataset:
ctypes
.
c_int
(
len
(
feature_name
))))
ctypes
.
c_int
(
len
(
feature_name
))))
return
self
return
self
def
set_label
(
self
,
label
):
def
set_label
(
self
,
label
)
->
"Dataset"
:
"""Set label of Dataset.
"""Set label of Dataset.
Parameters
Parameters
...
@@ -2227,7 +2245,7 @@ class Dataset:
...
@@ -2227,7 +2245,7 @@ class Dataset:
self
.
label
=
self
.
get_field
(
'label'
)
# original values can be modified at cpp side
self
.
label
=
self
.
get_field
(
'label'
)
# original values can be modified at cpp side
return
self
return
self
def
set_weight
(
self
,
weight
):
def
set_weight
(
self
,
weight
)
->
"Dataset"
:
"""Set weight of each instance.
"""Set weight of each instance.
Parameters
Parameters
...
@@ -2249,7 +2267,7 @@ class Dataset:
...
@@ -2249,7 +2267,7 @@ class Dataset:
self
.
weight
=
self
.
get_field
(
'weight'
)
# original values can be modified at cpp side
self
.
weight
=
self
.
get_field
(
'weight'
)
# original values can be modified at cpp side
return
self
return
self
def
set_init_score
(
self
,
init_score
):
def
set_init_score
(
self
,
init_score
)
->
"Dataset"
:
"""Set init score of Booster to start from.
"""Set init score of Booster to start from.
Parameters
Parameters
...
@@ -2268,7 +2286,7 @@ class Dataset:
...
@@ -2268,7 +2286,7 @@ class Dataset:
self
.
init_score
=
self
.
get_field
(
'init_score'
)
# original values can be modified at cpp side
self
.
init_score
=
self
.
get_field
(
'init_score'
)
# original values can be modified at cpp side
return
self
return
self
def
set_group
(
self
,
group
):
def
set_group
(
self
,
group
)
->
"Dataset"
:
"""Set group size of Dataset (used for ranking).
"""Set group size of Dataset (used for ranking).
Parameters
Parameters
...
@@ -2330,7 +2348,7 @@ class Dataset:
...
@@ -2330,7 +2348,7 @@ class Dataset:
ptr_string_buffers
))
ptr_string_buffers
))
return
[
string_buffers
[
i
].
value
.
decode
(
'utf-8'
)
for
i
in
range
(
num_feature
)]
return
[
string_buffers
[
i
].
value
.
decode
(
'utf-8'
)
for
i
in
range
(
num_feature
)]
def
get_label
(
self
):
def
get_label
(
self
)
->
Optional
[
np
.
ndarray
]
:
"""Get the label of the Dataset.
"""Get the label of the Dataset.
Returns
Returns
...
@@ -2342,7 +2360,7 @@ class Dataset:
...
@@ -2342,7 +2360,7 @@ class Dataset:
self
.
label
=
self
.
get_field
(
'label'
)
self
.
label
=
self
.
get_field
(
'label'
)
return
self
.
label
return
self
.
label
def
get_weight
(
self
):
def
get_weight
(
self
)
->
Optional
[
np
.
ndarray
]
:
"""Get the weight of the Dataset.
"""Get the weight of the Dataset.
Returns
Returns
...
@@ -2354,7 +2372,7 @@ class Dataset:
...
@@ -2354,7 +2372,7 @@ class Dataset:
self
.
weight
=
self
.
get_field
(
'weight'
)
self
.
weight
=
self
.
get_field
(
'weight'
)
return
self
.
weight
return
self
.
weight
def
get_init_score
(
self
):
def
get_init_score
(
self
)
->
Optional
[
np
.
ndarray
]
:
"""Get the initial score of the Dataset.
"""Get the initial score of the Dataset.
Returns
Returns
...
@@ -2473,7 +2491,7 @@ class Dataset:
...
@@ -2473,7 +2491,7 @@ class Dataset:
else
:
else
:
raise
LightGBMError
(
"Cannot get feature_num_bin before construct dataset"
)
raise
LightGBMError
(
"Cannot get feature_num_bin before construct dataset"
)
def
get_ref_chain
(
self
,
ref_limit
=
100
)
:
def
get_ref_chain
(
self
,
ref_limit
:
int
=
100
)
->
Set
[
"Dataset"
]
:
"""Get a chain of Dataset objects.
"""Get a chain of Dataset objects.
Starts with r, then goes to r.reference (if exists),
Starts with r, then goes to r.reference (if exists),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment