"...composable_kernel.git" did not exist on "c71e140d32a4fcadee95415b7cb8fde05e68e02a"
Unverified Commit 3f6cbcc8 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix typehint (#4700)

parent 8490f7e4
......@@ -7,16 +7,21 @@ Deduplicate repeated parameters.
No guarantee for forward-compatibility.
"""
from __future__ import annotations
import logging
import typing
import nni
from .formatting import deformat_parameters
_logger = logging.getLogger(__name__)
from .formatting import FormattedParameters, FormattedSearchSpace, ParameterSpec, deformat_parameters
# TODO:
# Move main logic of basic tuners (random and grid search) into SDK,
# so we can get rid of private methods and circular dependency.
if typing.TYPE_CHECKING:
from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner
_logger = logging.getLogger(__name__)
class Deduplicator:
"""
......@@ -36,13 +41,13 @@ class Deduplicator:
See random tuner's source code for example usage.
"""
def __init__(self, formatted_search_space):
self._space = formatted_search_space
self._never_dup = any(_spec_never_dup(spec) for spec in self._space.values())
self._history = set()
self._grid_search = None
def __init__(self, formatted_search_space: FormattedSearchSpace):
self._space: FormattedSearchSpace = formatted_search_space
self._never_dup: bool = any(_spec_never_dup(spec) for spec in self._space.values())
self._history: set[str] = set()
self._grid_search: GridSearchTuner | None = None
def __call__(self, formatted_parameters):
def __call__(self, formatted_parameters: FormattedParameters) -> FormattedParameters:
if self._never_dup or self._not_dup(formatted_parameters):
return formatted_parameters
......@@ -52,29 +57,29 @@ class Deduplicator:
self._init_grid_search()
while True:
new = self._grid_search._suggest()
new = self._grid_search._suggest() # type: ignore
if new is None:
raise nni.NoMoreTrialError()
if self._not_dup(new):
return new
def _init_grid_search(self):
def _init_grid_search(self) -> None:
from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner
self._grid_search = GridSearchTuner()
self._grid_search.history = self._history
self._grid_search.space = self._space
self._grid_search._init_grid()
def _not_dup(self, formatted_parameters):
def _not_dup(self, formatted_parameters: FormattedParameters) -> bool:
params = deformat_parameters(formatted_parameters, self._space)
params_str = nni.dump(params, sort_keys=True)
params_str = typing.cast(str, nni.dump(params, sort_keys=True))
if params_str in self._history:
return False
else:
self._history.add(params_str)
return True
def _spec_never_dup(spec):
def _spec_never_dup(spec: ParameterSpec) -> bool:
if spec.is_nested():
return False # "not chosen" duplicates with "not chosen"
if spec.categorical or spec.q is not None:
......
......@@ -27,7 +27,7 @@ _dispatcher_env_var_names = [
def _load_env_vars(env_var_names):
env_var_dict = {k: os.environ.get(k) for k in env_var_names}
return namedtuple('EnvVars', env_var_names)(**env_var_dict)
return namedtuple('EnvVars', env_var_names)(**env_var_dict) # pylint: disable=unused-variable
trial_env_vars = _load_env_vars(_trial_env_var_names)
......
......@@ -28,8 +28,8 @@ def main(argv):
def parse_nvidia_smi_result(smi, outputDir):
old_umask = os.umask(0)
try:
old_umask = os.umask(0)
xmldoc = minidom.parseString(smi)
gpuList = xmldoc.getElementsByTagName('gpu')
with open(os.path.join(outputDir, "gpu_metrics"), 'a') as outputFile:
......@@ -62,8 +62,8 @@ def parse_nvidia_smi_result(smi, outputDir):
def gen_empty_gpu_metric(outputDir):
old_umask = os.umask(0)
try:
old_umask = os.umask(0)
with open(os.path.join(outputDir, "gpu_metrics"), 'a') as outputFile:
outPut = {}
outPut["Timestamp"] = time.asctime(time.localtime())
......
......@@ -65,6 +65,7 @@ stages:
- script: |
python -m pyright nni
displayName: Type Check
- job: typescript
pool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment