Unverified Commit dc1bc23a authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Booster.set_network() (#4068)

* [python-package] add type hints on Booster.set_network()

* change behavior
parent b044070e
...@@ -9,7 +9,7 @@ from copy import deepcopy ...@@ -9,7 +9,7 @@ from copy import deepcopy
from functools import wraps from functools import wraps
from logging import Logger from logging import Logger
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, Dict from typing import Any, Dict, List, Set, Union
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
...@@ -2336,8 +2336,13 @@ class Booster: ...@@ -2336,8 +2336,13 @@ class Booster:
self.__is_predicted_cur_iter = [] self.__is_predicted_cur_iter = []
return self return self
def set_network(self, machines, local_listen_port=12400, def set_network(
listen_time_out=120, num_machines=1): self,
machines: Union[List[str], Set[str], str],
local_listen_port: int = 12400,
listen_time_out: int = 120,
num_machines: int = 1
) -> "Booster":
"""Set the network configuration. """Set the network configuration.
Parameters Parameters
...@@ -2356,6 +2361,8 @@ class Booster: ...@@ -2356,6 +2361,8 @@ class Booster:
self : Booster self : Booster
Booster with set network. Booster with set network.
""" """
if isinstance(machines, (list, set)):
machines = ','.join(machines)
_safe_call(_LIB.LGBM_NetworkInit(c_str(machines), _safe_call(_LIB.LGBM_NetworkInit(c_str(machines),
ctypes.c_int(local_listen_port), ctypes.c_int(local_listen_port),
ctypes.c_int(listen_time_out), ctypes.c_int(listen_time_out),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment