Unverified Commit d38359e2 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Pin pyright version (#4902)

parent f9ea49ff
...@@ -7,7 +7,7 @@ jupyter ...@@ -7,7 +7,7 @@ jupyter
jupyterlab == 3.0.9 jupyterlab == 3.0.9
nbsphinx nbsphinx
pylint pylint
pyright pyright == 1.1.250
pytest pytest
pytest-cov pytest-cov
rstcheck rstcheck
......
...@@ -18,4 +18,5 @@ matplotlib ...@@ -18,4 +18,5 @@ matplotlib
# TODO: time to drop tensorflow 1.x # TODO: time to drop tensorflow 1.x
keras keras
tensorflow < 2.0 tensorflow < 2.0
protobuf <= 3.20.1
timm >= 0.5.4 timm >= 0.5.4
\ No newline at end of file
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Dict, List, Callable, Optional from typing import Dict, List, Callable, Optional, cast
import json_tricks import json_tricks
import torch import torch
...@@ -73,7 +73,8 @@ class AMCTaskGenerator(TaskGenerator): ...@@ -73,7 +73,8 @@ class AMCTaskGenerator(TaskGenerator):
total_sparsity = config_list_copy[0]['total_sparsity'] total_sparsity = config_list_copy[0]['total_sparsity']
max_sparsity_per_layer = config_list_copy[0].get('max_sparsity_per_layer', 1.) max_sparsity_per_layer = config_list_copy[0].get('max_sparsity_per_layer', 1.)
self.env = AMCEnv(origin_model, origin_config_list, self.dummy_input, total_sparsity, max_sparsity_per_layer, self.target) self.env = AMCEnv(origin_model, origin_config_list, self.dummy_input, total_sparsity,
cast(Dict[str, float], max_sparsity_per_layer), self.target)
self.agent = DDPG(len(self.env.state_feature), 1, self.ddpg_params) self.agent = DDPG(len(self.env.state_feature), 1, self.ddpg_params)
self.agent.is_training = True self.agent.is_training = True
task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None) task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
......
...@@ -25,8 +25,8 @@ class AMCEnv: ...@@ -25,8 +25,8 @@ class AMCEnv:
for i, (name, layer) in enumerate(model.named_modules()): for i, (name, layer) in enumerate(model.named_modules()):
if name in pruning_op_names: if name in pruning_op_names:
op_type = type(layer).__name__ op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore
self.pruning_ops[name] = (i, op_type, stride, kernel_size) self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type) self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types)) self.pruning_types = list(set(self.pruning_types))
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from collections import deque, namedtuple from collections import deque, namedtuple
from typing import Any, List from typing import Any, List, cast
import warnings import warnings
import random import random
...@@ -174,7 +174,7 @@ class SequentialMemory(Memory): ...@@ -174,7 +174,7 @@ class SequentialMemory(Memory):
# to the right. Again, we need to be careful to not include an observation from the next # to the right. Again, we need to be careful to not include an observation from the next
# episode if the last state is terminal. # episode if the last state is terminal.
state1 = [np.copy(x) for x in state0[1:]] state1 = [np.copy(x) for x in state0[1:]]
state1.append(self.observations[idx]) state1.append(cast(np.ndarray, self.observations[idx]))
assert len(state0) == self.window_length assert len(state0) == self.window_length
assert len(state1) == len(state0) assert len(state1) == len(state0)
......
...@@ -13,7 +13,7 @@ import sys ...@@ -13,7 +13,7 @@ import sys
import types import types
import warnings import warnings
from io import IOBase from io import IOBase
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple, Union, cast
import cloudpickle # use cloudpickle as backend for unserializable types and instances import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend import json_tricks # use json_tricks as serializer backend
...@@ -604,7 +604,7 @@ class _unwrap_metaclass(type): ...@@ -604,7 +604,7 @@ class _unwrap_metaclass(type):
def __new__(cls, name, bases, dct): def __new__(cls, name, bases, dct):
bases = tuple([getattr(base, '__wrapped__', base) for base in bases]) bases = tuple([getattr(base, '__wrapped__', base) for base in bases])
return super().__new__(cls, name, bases, dct) return super().__new__(cls, name, cast(Tuple[type, ...], bases), dct)
# Using a customized "bases" breaks default isinstance and issubclass. # Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only. # We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union from typing import Any, Iterable, NewType, List, Union, Type
from ..graph import Model, MetricData from ..graph import Model, MetricData
...@@ -12,7 +12,7 @@ __all__ = [ ...@@ -12,7 +12,7 @@ __all__ = [
] ]
GraphData = NewType('GraphData', Any) GraphData: Type[Any] = NewType('GraphData', Any)
""" """
A _serializable_ internal data type defined by execution engine. A _serializable_ internal data type defined by execution engine.
...@@ -26,7 +26,7 @@ See `AbstractExecutionEngine` for details. ...@@ -26,7 +26,7 @@ See `AbstractExecutionEngine` for details.
""" """
WorkerInfo = NewType('WorkerInfo', Any) WorkerInfo: Type[Any] = NewType('WorkerInfo', Any)
""" """
To be designed. Discussion needed. To be designed. Discussion needed.
...@@ -114,7 +114,7 @@ class AbstractExecutionEngine(ABC): ...@@ -114,7 +114,7 @@ class AbstractExecutionEngine(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]: def query_available_resource(self) -> Union[List[WorkerInfo], int]: # type: ignore
""" """
Returns information of all idle workers. Returns information of all idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers. If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import Dict, Any, Type from typing import Dict, Any, Type, cast
import torch.nn as nn import torch.nn as nn
...@@ -53,7 +53,10 @@ class PurePythonExecutionEngine(BaseExecutionEngine): ...@@ -53,7 +53,10 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
def pack_model_data(cls, model: Model) -> Any: def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model) mutation = get_mutation_dict(model)
assert model.evaluator is not None, 'Model evaluator is not available.' assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator) graph_data = PythonGraphData(
cast(Type[nn.Module], model.python_class),
model.python_init_params or {}, mutation, model.evaluator
)
return graph_data return graph_data
@classmethod @classmethod
......
...@@ -351,7 +351,7 @@ class RetiariiExperiment(Experiment): ...@@ -351,7 +351,7 @@ class RetiariiExperiment(Experiment):
# when strategy hasn't implemented its own export logic # when strategy hasn't implemented its own export logic
all_models = filter(lambda m: m.metric is not None, list_models()) all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize'] assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize') all_models = sorted(all_models, key=lambda m: cast(float, m.metric), reverse=optimize_mode == 'maximize')
assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.' assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code': if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]] return [model_to_pytorch_script(model) for model in all_models[:top_k]]
......
...@@ -56,8 +56,8 @@ class Evaluator(abc.ABC): ...@@ -56,8 +56,8 @@ class Evaluator(abc.ABC):
if subclass.__name__ == evaluator_type: if subclass.__name__ == evaluator_type:
evaluator_type = subclass evaluator_type = subclass
break break
assert issubclass(evaluator_type, Evaluator) assert issubclass(cast(type, evaluator_type), Evaluator)
return evaluator_type._load(ir) return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod @abc.abstractmethod
def _dump(self) -> Any: def _dump(self) -> Any:
...@@ -350,7 +350,7 @@ class Graph: ...@@ -350,7 +350,7 @@ class Graph:
if isinstance(operation_or_type, Operation): if isinstance(operation_or_type, Operation):
op = operation_or_type op = operation_or_type
else: else:
op = Operation.new(operation_or_type, parameters, name) op = Operation.new(operation_or_type, cast(dict, parameters), name)
return Node(self, uid(), name, op, _internal=True)._register() return Node(self, uid(), name, op, _internal=True)._register()
@overload @overload
...@@ -363,7 +363,7 @@ class Graph: ...@@ -363,7 +363,7 @@ class Graph:
if isinstance(operation_or_type, Operation): if isinstance(operation_or_type, Operation):
op = operation_or_type op = operation_or_type
else: else:
op = Operation.new(operation_or_type, parameters, name) op = Operation.new(operation_or_type, cast(dict, parameters), name)
new_node = Node(self, uid(), name, op, _internal=True)._register() new_node = Node(self, uid(), name, op, _internal=True)._register()
# update edges # update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None)) self.add_edge((edge.head, edge.head_slot), (new_node, None))
......
...@@ -11,16 +11,18 @@ from nni.common.version import version_check ...@@ -11,16 +11,18 @@ from nni.common.version import version_check
# because it would induce cycled import # because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any) RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
_advisor: 'RetiariiAdvisor' = None _advisor = None # type is RetiariiAdvisor
def get_advisor() -> 'RetiariiAdvisor': def get_advisor():
# return type: RetiariiAdvisor
global _advisor global _advisor
assert _advisor is not None assert _advisor is not None
return _advisor return _advisor
def register_advisor(advisor: 'RetiariiAdvisor'): def register_advisor(advisor):
# type of advisor: RetiariiAdvisor
global _advisor global _advisor
if _advisor is not None: if _advisor is not None:
warnings.warn('Advisor is already set.' warnings.warn('Advisor is already set.'
......
...@@ -141,7 +141,7 @@ class ModelEvaluationEnv(gym.Env[ObservationType, int]): ...@@ -141,7 +141,7 @@ class ModelEvaluationEnv(gym.Env[ObservationType, int]):
wait_models(model) wait_models(model)
if model.status == ModelStatus.Failed: if model.status == ModelStatus.Failed:
return self.reset(), 0., False, {} return self.reset(), 0., False, {}
rew = float(model.metric) rew = float(model.metric) # type: ignore
_logger.info(f'Model metric received as reward: {rew}') _logger.info(f'Model metric received as reward: {rew}')
return obs, rew, True, {} return obs, rew, True, {}
else: else:
......
...@@ -93,7 +93,7 @@ def create_validator_instance(algo_type, builtin_name): ...@@ -93,7 +93,7 @@ def create_validator_instance(algo_type, builtin_name):
module_name, class_name = parse_full_class_name(meta['classArgsValidator']) module_name, class_name = parse_full_class_name(meta['classArgsValidator'])
assert module_name is not None assert module_name is not None
class_module = importlib.import_module(module_name) class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, class_name) class_constructor = getattr(class_module, class_name) # type: ignore
return class_constructor() return class_constructor()
...@@ -149,7 +149,7 @@ def create_builtin_class_instance( ...@@ -149,7 +149,7 @@ def create_builtin_class_instance(
raise RuntimeError('Builtin module can not be loaded: {}'.format(module_name)) raise RuntimeError('Builtin module can not be loaded: {}'.format(module_name))
class_module = importlib.import_module(module_name) class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, class_name) class_constructor = getattr(class_module, class_name) # type: ignore
instance = class_constructor(**class_args) instance = class_constructor(**class_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment