Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d38359e2
Unverified
Commit
d38359e2
authored
May 31, 2022
by
Yuge Zhang
Committed by
GitHub
May 31, 2022
Browse files
Pin pyright version (#4902)
parent
f9ea49ff
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
33 additions
and
26 deletions
+33
-26
dependencies/develop.txt
dependencies/develop.txt
+1
-1
dependencies/recommended_legacy.txt
dependencies/recommended_legacy.txt
+1
-0
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
+3
-2
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/amc_env.py
...ms/compression/v2/pytorch/pruning/tools/rl_env/amc_env.py
+2
-2
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/memory.py
...hms/compression/v2/pytorch/pruning/tools/rl_env/memory.py
+2
-2
nni/common/serializer.py
nni/common/serializer.py
+2
-2
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+4
-4
nni/retiarii/execution/python.py
nni/retiarii/execution/python.py
+5
-2
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+1
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+4
-4
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+5
-3
nni/retiarii/strategy/_rl_impl.py
nni/retiarii/strategy/_rl_impl.py
+1
-1
nni/tools/package_utils/tuner_factory.py
nni/tools/package_utils/tuner_factory.py
+2
-2
No files found.
dependencies/develop.txt
View file @
d38359e2
...
...
@@ -7,7 +7,7 @@ jupyter
jupyterlab == 3.0.9
nbsphinx
pylint
pyright
pyright
== 1.1.250
pytest
pytest-cov
rstcheck
...
...
dependencies/recommended_legacy.txt
View file @
d38359e2
...
...
@@ -18,4 +18,5 @@ matplotlib
# TODO: time to drop tensorflow 1.x
keras
tensorflow < 2.0
protobuf <= 3.20.1
timm >= 0.5.4
\ No newline at end of file
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
View file @
d38359e2
...
...
@@ -3,7 +3,7 @@
from
copy
import
deepcopy
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Callable
,
Optional
from
typing
import
Dict
,
List
,
Callable
,
Optional
,
cast
import
json_tricks
import
torch
...
...
@@ -73,7 +73,8 @@ class AMCTaskGenerator(TaskGenerator):
total_sparsity
=
config_list_copy
[
0
][
'total_sparsity'
]
max_sparsity_per_layer
=
config_list_copy
[
0
].
get
(
'max_sparsity_per_layer'
,
1.
)
self
.
env
=
AMCEnv
(
origin_model
,
origin_config_list
,
self
.
dummy_input
,
total_sparsity
,
max_sparsity_per_layer
,
self
.
target
)
self
.
env
=
AMCEnv
(
origin_model
,
origin_config_list
,
self
.
dummy_input
,
total_sparsity
,
cast
(
Dict
[
str
,
float
],
max_sparsity_per_layer
),
self
.
target
)
self
.
agent
=
DDPG
(
len
(
self
.
env
.
state_feature
),
1
,
self
.
ddpg_params
)
self
.
agent
.
is_training
=
True
task_result
=
TaskResult
(
'origin'
,
origin_model
,
origin_masks
,
origin_masks
,
None
)
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/amc_env.py
View file @
d38359e2
...
...
@@ -25,8 +25,8 @@ class AMCEnv:
for
i
,
(
name
,
layer
)
in
enumerate
(
model
.
named_modules
()):
if
name
in
pruning_op_names
:
op_type
=
type
(
layer
).
__name__
stride
=
np
.
power
(
np
.
prod
(
layer
.
stride
),
1
/
len
(
layer
.
stride
))
if
hasattr
(
layer
,
'stride'
)
else
0
kernel_size
=
np
.
power
(
np
.
prod
(
layer
.
kernel_size
),
1
/
len
(
layer
.
kernel_size
))
if
hasattr
(
layer
,
'kernel_size'
)
else
1
stride
=
np
.
power
(
np
.
prod
(
layer
.
stride
),
1
/
len
(
layer
.
stride
))
if
hasattr
(
layer
,
'stride'
)
else
0
# type: ignore
kernel_size
=
np
.
power
(
np
.
prod
(
layer
.
kernel_size
),
1
/
len
(
layer
.
kernel_size
))
if
hasattr
(
layer
,
'kernel_size'
)
else
1
# type: ignore
self
.
pruning_ops
[
name
]
=
(
i
,
op_type
,
stride
,
kernel_size
)
self
.
pruning_types
.
append
(
op_type
)
self
.
pruning_types
=
list
(
set
(
self
.
pruning_types
))
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/memory.py
View file @
d38359e2
...
...
@@ -3,7 +3,7 @@
from
__future__
import
absolute_import
from
collections
import
deque
,
namedtuple
from
typing
import
Any
,
List
from
typing
import
Any
,
List
,
cast
import
warnings
import
random
...
...
@@ -174,7 +174,7 @@ class SequentialMemory(Memory):
# to the right. Again, we need to be careful to not include an observation from the next
# episode if the last state is terminal.
state1
=
[
np
.
copy
(
x
)
for
x
in
state0
[
1
:]]
state1
.
append
(
self
.
observations
[
idx
])
state1
.
append
(
cast
(
np
.
ndarray
,
self
.
observations
[
idx
])
)
assert
len
(
state0
)
==
self
.
window_length
assert
len
(
state1
)
==
len
(
state0
)
...
...
nni/common/serializer.py
View file @
d38359e2
...
...
@@ -13,7 +13,7 @@ import sys
import
types
import
warnings
from
io
import
IOBase
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Tuple
,
Union
,
cast
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
json_tricks
# use json_tricks as serializer backend
...
...
@@ -604,7 +604,7 @@ class _unwrap_metaclass(type):
def
__new__
(
cls
,
name
,
bases
,
dct
):
bases
=
tuple
([
getattr
(
base
,
'__wrapped__'
,
base
)
for
base
in
bases
])
return
super
().
__new__
(
cls
,
name
,
bases
,
dct
)
return
super
().
__new__
(
cls
,
name
,
cast
(
Tuple
[
type
,
...],
bases
)
,
dct
)
# Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only.
...
...
nni/retiarii/execution/interface.py
View file @
d38359e2
...
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
Iterable
,
NewType
,
List
,
Union
from
typing
import
Any
,
Iterable
,
NewType
,
List
,
Union
,
Type
from
..graph
import
Model
,
MetricData
...
...
@@ -12,7 +12,7 @@ __all__ = [
]
GraphData
=
NewType
(
'GraphData'
,
Any
)
GraphData
:
Type
[
Any
]
=
NewType
(
'GraphData'
,
Any
)
"""
A _serializable_ internal data type defined by execution engine.
...
...
@@ -26,7 +26,7 @@ See `AbstractExecutionEngine` for details.
"""
WorkerInfo
=
NewType
(
'WorkerInfo'
,
Any
)
WorkerInfo
:
Type
[
Any
]
=
NewType
(
'WorkerInfo'
,
Any
)
"""
To be designed. Discussion needed.
...
...
@@ -114,7 +114,7 @@ class AbstractExecutionEngine(ABC):
raise
NotImplementedError
@
abstractmethod
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
# type: ignore
"""
Returns information of all idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
...
...
nni/retiarii/execution/python.py
View file @
d38359e2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
Dict
,
Any
,
Type
from
typing
import
Dict
,
Any
,
Type
,
cast
import
torch.nn
as
nn
...
...
@@ -53,7 +53,10 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
mutation
=
get_mutation_dict
(
model
)
assert
model
.
evaluator
is
not
None
,
'Model evaluator is not available.'
graph_data
=
PythonGraphData
(
model
.
python_class
,
model
.
python_init_params
or
{},
mutation
,
model
.
evaluator
)
graph_data
=
PythonGraphData
(
cast
(
Type
[
nn
.
Module
],
model
.
python_class
),
model
.
python_init_params
or
{},
mutation
,
model
.
evaluator
)
return
graph_data
@
classmethod
...
...
nni/retiarii/experiment/pytorch.py
View file @
d38359e2
...
...
@@ -351,7 +351,7 @@ class RetiariiExperiment(Experiment):
# when strategy hasn't implemented its own export logic
all_models
=
filter
(
lambda
m
:
m
.
metric
is
not
None
,
list_models
())
assert
optimize_mode
in
[
'maximize'
,
'minimize'
]
all_models
=
sorted
(
all_models
,
key
=
lambda
m
:
m
.
metric
,
reverse
=
optimize_mode
==
'maximize'
)
all_models
=
sorted
(
all_models
,
key
=
lambda
m
:
cast
(
float
,
m
.
metric
)
,
reverse
=
optimize_mode
==
'maximize'
)
assert
formatter
in
[
'code'
,
'dict'
],
'Export formatter other than "code" and "dict" is not supported yet.'
if
formatter
==
'code'
:
return
[
model_to_pytorch_script
(
model
)
for
model
in
all_models
[:
top_k
]]
...
...
nni/retiarii/graph.py
View file @
d38359e2
...
...
@@ -56,8 +56,8 @@ class Evaluator(abc.ABC):
if
subclass
.
__name__
==
evaluator_type
:
evaluator_type
=
subclass
break
assert
issubclass
(
evaluator_type
,
Evaluator
)
return
evaluator_type
.
_load
(
ir
)
assert
issubclass
(
cast
(
type
,
evaluator_type
)
,
Evaluator
)
return
cast
(
Type
[
Evaluator
],
evaluator_type
)
.
_load
(
ir
)
@
abc
.
abstractmethod
def
_dump
(
self
)
->
Any
:
...
...
@@ -350,7 +350,7 @@ class Graph:
if
isinstance
(
operation_or_type
,
Operation
):
op
=
operation_or_type
else
:
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
op
=
Operation
.
new
(
operation_or_type
,
cast
(
dict
,
parameters
)
,
name
)
return
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
@
overload
...
...
@@ -363,7 +363,7 @@ class Graph:
if
isinstance
(
operation_or_type
,
Operation
):
op
=
operation_or_type
else
:
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
op
=
Operation
.
new
(
operation_or_type
,
cast
(
dict
,
parameters
)
,
name
)
new_node
=
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
# update edges
self
.
add_edge
((
edge
.
head
,
edge
.
head_slot
),
(
new_node
,
None
))
...
...
nni/retiarii/integration_api.py
View file @
d38359e2
...
...
@@ -11,16 +11,18 @@ from nni.common.version import version_check
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
_advisor
:
'
RetiariiAdvisor
'
=
None
_advisor
=
None
# type is
RetiariiAdvisor
def
get_advisor
()
->
'RetiariiAdvisor'
:
def
get_advisor
():
# return type: RetiariiAdvisor
global
_advisor
assert
_advisor
is
not
None
return
_advisor
def
register_advisor
(
advisor
:
'RetiariiAdvisor'
):
def
register_advisor
(
advisor
):
# type of advisor: RetiariiAdvisor
global
_advisor
if
_advisor
is
not
None
:
warnings
.
warn
(
'Advisor is already set.'
...
...
nni/retiarii/strategy/_rl_impl.py
View file @
d38359e2
...
...
@@ -141,7 +141,7 @@ class ModelEvaluationEnv(gym.Env[ObservationType, int]):
wait_models
(
model
)
if
model
.
status
==
ModelStatus
.
Failed
:
return
self
.
reset
(),
0.
,
False
,
{}
rew
=
float
(
model
.
metric
)
rew
=
float
(
model
.
metric
)
# type: ignore
_logger
.
info
(
f
'Model metric received as reward:
{
rew
}
'
)
return
obs
,
rew
,
True
,
{}
else
:
...
...
nni/tools/package_utils/tuner_factory.py
View file @
d38359e2
...
...
@@ -93,7 +93,7 @@ def create_validator_instance(algo_type, builtin_name):
module_name
,
class_name
=
parse_full_class_name
(
meta
[
'classArgsValidator'
])
assert
module_name
is
not
None
class_module
=
importlib
.
import_module
(
module_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
# type: ignore
return
class_constructor
()
...
...
@@ -149,7 +149,7 @@ def create_builtin_class_instance(
raise
RuntimeError
(
'Builtin module can not be loaded: {}'
.
format
(
module_name
))
class_module
=
importlib
.
import_module
(
module_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
# type: ignore
instance
=
class_constructor
(
**
class_args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment