Unverified commit d38359e2 authored by Yuge Zhang, committed by GitHub

Pin pyright version (#4902)

parent f9ea49ff
@@ -7,7 +7,7 @@ jupyter
jupyterlab == 3.0.9
nbsphinx
pylint
pyright
pyright == 1.1.250
pytest
pytest-cov
rstcheck
......
@@ -18,4 +18,5 @@ matplotlib
# TODO: time to drop tensorflow 1.x
keras
tensorflow < 2.0
protobuf <= 3.20.1
timm >= 0.5.4
\ No newline at end of file
@@ -3,7 +3,7 @@
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional
from typing import Dict, List, Callable, Optional, cast
import json_tricks
import torch
@@ -73,7 +73,8 @@ class AMCTaskGenerator(TaskGenerator):
total_sparsity = config_list_copy[0]['total_sparsity']
max_sparsity_per_layer = config_list_copy[0].get('max_sparsity_per_layer', 1.)
self.env = AMCEnv(origin_model, origin_config_list, self.dummy_input, total_sparsity, max_sparsity_per_layer, self.target)
self.env = AMCEnv(origin_model, origin_config_list, self.dummy_input, total_sparsity,
cast(Dict[str, float], max_sparsity_per_layer), self.target)
self.agent = DDPG(len(self.env.state_feature), 1, self.ddpg_params)
self.agent.is_training = True
task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
......
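Most of the changes in this commit wrap values in typing.cast so the pinned pyright accepts them; cast is a no-op at runtime. A minimal standalone sketch of the pattern, loosely modeled on the sparsity config above (the function name and signature are illustrative, not part of this commit):

    from typing import Dict, Union, cast

    def max_sparsity(config: Dict[str, Union[float, Dict[str, float]]]) -> Dict[str, float]:
        value = config.get('max_sparsity_per_layer', {})
        # cast() simply returns `value`; it only instructs the type checker to
        # treat it as Dict[str, float] from this point on.
        return cast(Dict[str, float], value)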
@@ -25,8 +25,8 @@ class AMCEnv:
for i, (name, layer) in enumerate(model.named_modules()):
if name in pruning_op_names:
op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore
self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types))
......
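Where a cast would be clumsy, the commit silences pyright on a single line with a `# type: ignore` comment instead, as in the stride/kernel_size lookups above. A small sketch of that suppression pattern (assumes torch is installed; the variable names are made up for illustration):

    import torch.nn as nn

    layer: nn.Module = nn.Conv2d(3, 8, kernel_size=3)
    # nn.Module declares no `kernel_size` attribute, so a checker flags this access
    # even though the concrete Conv2d instance has it at runtime; the trailing
    # comment suppresses the diagnostic for this line only.
    kernel_size = layer.kernel_size  # type: ignore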
@@ -3,7 +3,7 @@
from __future__ import absolute_import
from collections import deque, namedtuple
from typing import Any, List
from typing import Any, List, cast
import warnings
import random
@@ -174,7 +174,7 @@ class SequentialMemory(Memory):
# to the right. Again, we need to be careful to not include an observation from the next
# episode if the last state is terminal.
state1 = [np.copy(x) for x in state0[1:]]
state1.append(self.observations[idx])
state1.append(cast(np.ndarray, self.observations[idx]))
assert len(state0) == self.window_length
assert len(state1) == len(state0)
......
@@ -13,7 +13,7 @@ import sys
import types
import warnings
from io import IOBase
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple, Union, cast
import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend
@@ -604,7 +604,7 @@ class _unwrap_metaclass(type):
def __new__(cls, name, bases, dct):
bases = tuple([getattr(base, '__wrapped__', base) for base in bases])
return super().__new__(cls, name, bases, dct)
return super().__new__(cls, name, cast(Tuple[type, ...], bases), dct)
# Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which concerns wrapped class only.
......
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union
from typing import Any, Iterable, NewType, List, Union, Type
from ..graph import Model, MetricData
@@ -12,7 +12,7 @@ __all__ = [
]
GraphData = NewType('GraphData', Any)
GraphData: Type[Any] = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by execution engine.
@@ -26,7 +26,7 @@ See `AbstractExecutionEngine` for details.
"""
WorkerInfo = NewType('WorkerInfo', Any)
WorkerInfo: Type[Any] = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.
@@ -114,7 +114,7 @@ class AbstractExecutionEngine(ABC):
raise NotImplementedError
@abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]:
def query_available_resource(self) -> Union[List[WorkerInfo], int]: # type: ignore
"""
Returns information of all idle workers.
If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.
......
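The explicit `Type[Any]` annotations on `GraphData` and `WorkerInfo` appear to be there only so the pinned pyright accepts a `NewType` built over `Any`; runtime behavior is unchanged. The pattern in isolation, with the imports it needs:

    from typing import Any, NewType, Type

    # NewType over Any is unusual; annotating the alias as Type[Any] keeps the
    # checker from rejecting or over-narrowing it, while the runtime value is
    # exactly the same NewType object as before.
    GraphData: Type[Any] = NewType('GraphData', Any)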
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Type
from typing import Dict, Any, Type, cast
import torch.nn as nn
@@ -53,7 +53,10 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
graph_data = PythonGraphData(
cast(Type[nn.Module], model.python_class),
model.python_init_params or {}, mutation, model.evaluator
)
return graph_data
@classmethod
......
@@ -351,7 +351,7 @@ class RetiariiExperiment(Experiment):
# when strategy hasn't implemented its own export logic
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
all_models = sorted(all_models, key=lambda m: cast(float, m.metric), reverse=optimize_mode == 'maximize')
assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
......
@@ -56,8 +56,8 @@ class Evaluator(abc.ABC):
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(evaluator_type, Evaluator)
return evaluator_type._load(ir)
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
@@ -350,7 +350,7 @@ class Graph:
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
op = Operation.new(operation_or_type, cast(dict, parameters), name)
return Node(self, uid(), name, op, _internal=True)._register()
@overload
@@ -363,7 +363,7 @@ class Graph:
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
op = Operation.new(operation_or_type, cast(dict, parameters), name)
new_node = Node(self, uid(), name, op, _internal=True)._register()
# update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None))
......
@@ -11,16 +11,18 @@ from nni.common.version import version_check
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
_advisor: 'RetiariiAdvisor' = None
_advisor = None # type is RetiariiAdvisor
def get_advisor() -> 'RetiariiAdvisor':
def get_advisor():
# return type: RetiariiAdvisor
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor: 'RetiariiAdvisor'):
def register_advisor(advisor):
# type of advisor: RetiariiAdvisor
global _advisor
if _advisor is not None:
warnings.warn('Advisor is already set.'
......
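In the hunk above, the annotations on `_advisor`, `get_advisor`, and `register_advisor` are demoted to plain comments, presumably because the `NewType('RetiariiAdvisor', Any)` alias is not usable as an annotation under the pinned pyright. A sketch of the same workaround (`Advisor` is a hypothetical stand-in, not the real class):

    from typing import Any, NewType

    Advisor = NewType('Advisor', Any)

    _advisor = None  # holds an Advisor once register_advisor() has been called

    def register_advisor(advisor):
        # `advisor` is expected to be an Advisor; the hint lives in a comment so the
        # checker never sees the problematic NewType-over-Any annotation.
        global _advisor
        _advisor = advisor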
@@ -141,7 +141,7 @@ class ModelEvaluationEnv(gym.Env[ObservationType, int]):
wait_models(model)
if model.status == ModelStatus.Failed:
return self.reset(), 0., False, {}
rew = float(model.metric)
rew = float(model.metric) # type: ignore
_logger.info(f'Model metric received as reward: {rew}')
return obs, rew, True, {}
else:
......
@@ -93,7 +93,7 @@ def create_validator_instance(algo_type, builtin_name):
module_name, class_name = parse_full_class_name(meta['classArgsValidator'])
assert module_name is not None
class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, class_name)
class_constructor = getattr(class_module, class_name) # type: ignore
return class_constructor()
@@ -149,7 +149,7 @@ def create_builtin_class_instance(
raise RuntimeError('Builtin module can not be loaded: {}'.format(module_name))
class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, class_name)
class_constructor = getattr(class_module, class_name) # type: ignore
instance = class_constructor(**class_args)
......