Unverified Commit f4bcda10 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Replace nni.typehint with typing_extensions (#4802)

parent 290824c1
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from schema import Schema, Optional from schema import Schema, Optional
from typing_extensions import Literal
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.assessor import Assessor, AssessResult from nni.assessor import Assessor, AssessResult
from nni.typehint import Literal
from nni.utils import extract_scalar_history from nni.utils import extract_scalar_history
logger = logging.getLogger('medianstop_Assessor') logger = logging.getLogger('medianstop_Assessor')
......
...@@ -22,10 +22,10 @@ from typing import Any, NamedTuple ...@@ -22,10 +22,10 @@ from typing import Any, NamedTuple
import numpy as np import numpy as np
from scipy.special import erf # pylint: disable=no-name-in-module from scipy.special import erf # pylint: disable=no-name-in-module
from typing_extensions import Literal
from nni.common.hpo_utils import Deduplicator, OptimizeMode, format_search_space, deformat_parameters, format_parameters from nni.common.hpo_utils import Deduplicator, OptimizeMode, format_search_space, deformat_parameters, format_parameters
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.typehint import Literal
from nni.utils import extract_scalar_reward from nni.utils import extract_scalar_reward
from . import random_tuner from . import random_tuner
......
...@@ -19,7 +19,7 @@ from gym import spaces ...@@ -19,7 +19,7 @@ from gym import spaces
from tianshou.data import to_torch from tianshou.data import to_torch
from tianshou.env.worker import EnvWorker from tianshou.env.worker import EnvWorker
from nni.typehint import TypedDict from typing_extensions import TypedDict
from .utils import get_targeted_model from .utils import get_targeted_model
from ..graph import ModelStatus from ..graph import ModelStatus
......
...@@ -7,7 +7,7 @@ __all__ = ['AlgoMeta'] ...@@ -7,7 +7,7 @@ __all__ = ['AlgoMeta']
from typing import NamedTuple from typing import NamedTuple
from nni.typehint import Literal from typing_extensions import Literal
class AlgoMeta(NamedTuple): class AlgoMeta(NamedTuple):
name: str name: str
......
...@@ -13,7 +13,8 @@ import os ...@@ -13,7 +13,8 @@ import os
import sys import sys
from typing import Any from typing import Any
from nni.typehint import Literal from typing_extensions import Literal
from . import config_manager from . import config_manager
ALGO_TYPES = ['tuners', 'assessors'] ALGO_TYPES = ['tuners', 'assessors']
......
...@@ -5,18 +5,11 @@ ...@@ -5,18 +5,11 @@
Types for static checking. Types for static checking.
""" """
__all__ = [ __all__ = ['Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord']
'Literal', 'TypedDict',
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord', from typing import Any, Dict, List
]
from typing_extensions import Literal, TypedDict
import sys
from typing import Any, Dict, List, TYPE_CHECKING
if TYPE_CHECKING or sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
from typing_extensions import Literal, TypedDict
Parameters = Dict[str, Any] Parameters = Dict[str, Any]
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment