Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d38359e2
Unverified
Commit
d38359e2
authored
May 31, 2022
by
Yuge Zhang
Committed by
GitHub
May 31, 2022
Browse files
Pin pyright version (#4902)
parent
f9ea49ff
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
33 additions
and
26 deletions
+33
-26
dependencies/develop.txt
dependencies/develop.txt
+1
-1
dependencies/recommended_legacy.txt
dependencies/recommended_legacy.txt
+1
-0
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
+3
-2
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/amc_env.py
...ms/compression/v2/pytorch/pruning/tools/rl_env/amc_env.py
+2
-2
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/memory.py
...hms/compression/v2/pytorch/pruning/tools/rl_env/memory.py
+2
-2
nni/common/serializer.py
nni/common/serializer.py
+2
-2
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+4
-4
nni/retiarii/execution/python.py
nni/retiarii/execution/python.py
+5
-2
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+1
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+4
-4
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+5
-3
nni/retiarii/strategy/_rl_impl.py
nni/retiarii/strategy/_rl_impl.py
+1
-1
nni/tools/package_utils/tuner_factory.py
nni/tools/package_utils/tuner_factory.py
+2
-2
No files found.
dependencies/develop.txt
View file @
d38359e2
...
@@ -7,7 +7,7 @@ jupyter
...
@@ -7,7 +7,7 @@ jupyter
jupyterlab == 3.0.9
jupyterlab == 3.0.9
nbsphinx
nbsphinx
pylint
pylint
pyright
pyright
== 1.1.250
pytest
pytest
pytest-cov
pytest-cov
rstcheck
rstcheck
...
...
dependencies/recommended_legacy.txt
View file @
d38359e2
...
@@ -18,4 +18,5 @@ matplotlib
...
@@ -18,4 +18,5 @@ matplotlib
# TODO: time to drop tensorflow 1.x
# TODO: time to drop tensorflow 1.x
keras
keras
tensorflow < 2.0
tensorflow < 2.0
protobuf <= 3.20.1
timm >= 0.5.4
timm >= 0.5.4
\ No newline at end of file
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
View file @
d38359e2
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
copy
import
deepcopy
from
copy
import
deepcopy
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Callable
,
Optional
from
typing
import
Dict
,
List
,
Callable
,
Optional
,
cast
import
json_tricks
import
json_tricks
import
torch
import
torch
...
@@ -73,7 +73,8 @@ class AMCTaskGenerator(TaskGenerator):
...
@@ -73,7 +73,8 @@ class AMCTaskGenerator(TaskGenerator):
total_sparsity
=
config_list_copy
[
0
][
'total_sparsity'
]
total_sparsity
=
config_list_copy
[
0
][
'total_sparsity'
]
max_sparsity_per_layer
=
config_list_copy
[
0
].
get
(
'max_sparsity_per_layer'
,
1.
)
max_sparsity_per_layer
=
config_list_copy
[
0
].
get
(
'max_sparsity_per_layer'
,
1.
)
self
.
env
=
AMCEnv
(
origin_model
,
origin_config_list
,
self
.
dummy_input
,
total_sparsity
,
max_sparsity_per_layer
,
self
.
target
)
self
.
env
=
AMCEnv
(
origin_model
,
origin_config_list
,
self
.
dummy_input
,
total_sparsity
,
cast
(
Dict
[
str
,
float
],
max_sparsity_per_layer
),
self
.
target
)
self
.
agent
=
DDPG
(
len
(
self
.
env
.
state_feature
),
1
,
self
.
ddpg_params
)
self
.
agent
=
DDPG
(
len
(
self
.
env
.
state_feature
),
1
,
self
.
ddpg_params
)
self
.
agent
.
is_training
=
True
self
.
agent
.
is_training
=
True
task_result
=
TaskResult
(
'origin'
,
origin_model
,
origin_masks
,
origin_masks
,
None
)
task_result
=
TaskResult
(
'origin'
,
origin_model
,
origin_masks
,
origin_masks
,
None
)
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/amc_env.py
View file @
d38359e2
...
@@ -25,8 +25,8 @@ class AMCEnv:
...
@@ -25,8 +25,8 @@ class AMCEnv:
for
i
,
(
name
,
layer
)
in
enumerate
(
model
.
named_modules
()):
for
i
,
(
name
,
layer
)
in
enumerate
(
model
.
named_modules
()):
if
name
in
pruning_op_names
:
if
name
in
pruning_op_names
:
op_type
=
type
(
layer
).
__name__
op_type
=
type
(
layer
).
__name__
stride
=
np
.
power
(
np
.
prod
(
layer
.
stride
),
1
/
len
(
layer
.
stride
))
if
hasattr
(
layer
,
'stride'
)
else
0
stride
=
np
.
power
(
np
.
prod
(
layer
.
stride
),
1
/
len
(
layer
.
stride
))
if
hasattr
(
layer
,
'stride'
)
else
0
# type: ignore
kernel_size
=
np
.
power
(
np
.
prod
(
layer
.
kernel_size
),
1
/
len
(
layer
.
kernel_size
))
if
hasattr
(
layer
,
'kernel_size'
)
else
1
kernel_size
=
np
.
power
(
np
.
prod
(
layer
.
kernel_size
),
1
/
len
(
layer
.
kernel_size
))
if
hasattr
(
layer
,
'kernel_size'
)
else
1
# type: ignore
self
.
pruning_ops
[
name
]
=
(
i
,
op_type
,
stride
,
kernel_size
)
self
.
pruning_ops
[
name
]
=
(
i
,
op_type
,
stride
,
kernel_size
)
self
.
pruning_types
.
append
(
op_type
)
self
.
pruning_types
.
append
(
op_type
)
self
.
pruning_types
=
list
(
set
(
self
.
pruning_types
))
self
.
pruning_types
=
list
(
set
(
self
.
pruning_types
))
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/rl_env/memory.py
View file @
d38359e2
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
collections
import
deque
,
namedtuple
from
collections
import
deque
,
namedtuple
from
typing
import
Any
,
List
from
typing
import
Any
,
List
,
cast
import
warnings
import
warnings
import
random
import
random
...
@@ -174,7 +174,7 @@ class SequentialMemory(Memory):
...
@@ -174,7 +174,7 @@ class SequentialMemory(Memory):
# to the right. Again, we need to be careful to not include an observation from the next
# to the right. Again, we need to be careful to not include an observation from the next
# episode if the last state is terminal.
# episode if the last state is terminal.
state1
=
[
np
.
copy
(
x
)
for
x
in
state0
[
1
:]]
state1
=
[
np
.
copy
(
x
)
for
x
in
state0
[
1
:]]
state1
.
append
(
self
.
observations
[
idx
])
state1
.
append
(
cast
(
np
.
ndarray
,
self
.
observations
[
idx
])
)
assert
len
(
state0
)
==
self
.
window_length
assert
len
(
state0
)
==
self
.
window_length
assert
len
(
state1
)
==
len
(
state0
)
assert
len
(
state1
)
==
len
(
state0
)
...
...
nni/common/serializer.py
View file @
d38359e2
...
@@ -13,7 +13,7 @@ import sys
...
@@ -13,7 +13,7 @@ import sys
import
types
import
types
import
warnings
import
warnings
from
io
import
IOBase
from
io
import
IOBase
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Tuple
,
Union
,
cast
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
json_tricks
# use json_tricks as serializer backend
import
json_tricks
# use json_tricks as serializer backend
...
@@ -604,7 +604,7 @@ class _unwrap_metaclass(type):
...
@@ -604,7 +604,7 @@ class _unwrap_metaclass(type):
def
__new__
(
cls
,
name
,
bases
,
dct
):
def
__new__
(
cls
,
name
,
bases
,
dct
):
bases
=
tuple
([
getattr
(
base
,
'__wrapped__'
,
base
)
for
base
in
bases
])
bases
=
tuple
([
getattr
(
base
,
'__wrapped__'
,
base
)
for
base
in
bases
])
return
super
().
__new__
(
cls
,
name
,
bases
,
dct
)
return
super
().
__new__
(
cls
,
name
,
cast
(
Tuple
[
type
,
...],
bases
)
,
dct
)
# Using a customized "bases" breaks default isinstance and issubclass.
# Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only.
# We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only.
...
...
nni/retiarii/execution/interface.py
View file @
d38359e2
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
Iterable
,
NewType
,
List
,
Union
from
typing
import
Any
,
Iterable
,
NewType
,
List
,
Union
,
Type
from
..graph
import
Model
,
MetricData
from
..graph
import
Model
,
MetricData
...
@@ -12,7 +12,7 @@ __all__ = [
...
@@ -12,7 +12,7 @@ __all__ = [
]
]
GraphData
=
NewType
(
'GraphData'
,
Any
)
GraphData
:
Type
[
Any
]
=
NewType
(
'GraphData'
,
Any
)
"""
"""
A _serializable_ internal data type defined by execution engine.
A _serializable_ internal data type defined by execution engine.
...
@@ -26,7 +26,7 @@ See `AbstractExecutionEngine` for details.
...
@@ -26,7 +26,7 @@ See `AbstractExecutionEngine` for details.
"""
"""
WorkerInfo
=
NewType
(
'WorkerInfo'
,
Any
)
WorkerInfo
:
Type
[
Any
]
=
NewType
(
'WorkerInfo'
,
Any
)
"""
"""
To be designed. Discussion needed.
To be designed. Discussion needed.
...
@@ -114,7 +114,7 @@ class AbstractExecutionEngine(ABC):
...
@@ -114,7 +114,7 @@ class AbstractExecutionEngine(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
# type: ignore
"""
"""
Returns information of all idle workers.
Returns information of all idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
...
...
nni/retiarii/execution/python.py
View file @
d38359e2
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
typing
import
Dict
,
Any
,
Type
from
typing
import
Dict
,
Any
,
Type
,
cast
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -53,7 +53,10 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
...
@@ -53,7 +53,10 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
mutation
=
get_mutation_dict
(
model
)
mutation
=
get_mutation_dict
(
model
)
assert
model
.
evaluator
is
not
None
,
'Model evaluator is not available.'
assert
model
.
evaluator
is
not
None
,
'Model evaluator is not available.'
graph_data
=
PythonGraphData
(
model
.
python_class
,
model
.
python_init_params
or
{},
mutation
,
model
.
evaluator
)
graph_data
=
PythonGraphData
(
cast
(
Type
[
nn
.
Module
],
model
.
python_class
),
model
.
python_init_params
or
{},
mutation
,
model
.
evaluator
)
return
graph_data
return
graph_data
@
classmethod
@
classmethod
...
...
nni/retiarii/experiment/pytorch.py
View file @
d38359e2
...
@@ -351,7 +351,7 @@ class RetiariiExperiment(Experiment):
...
@@ -351,7 +351,7 @@ class RetiariiExperiment(Experiment):
# when strategy hasn't implemented its own export logic
# when strategy hasn't implemented its own export logic
all_models
=
filter
(
lambda
m
:
m
.
metric
is
not
None
,
list_models
())
all_models
=
filter
(
lambda
m
:
m
.
metric
is
not
None
,
list_models
())
assert
optimize_mode
in
[
'maximize'
,
'minimize'
]
assert
optimize_mode
in
[
'maximize'
,
'minimize'
]
all_models
=
sorted
(
all_models
,
key
=
lambda
m
:
m
.
metric
,
reverse
=
optimize_mode
==
'maximize'
)
all_models
=
sorted
(
all_models
,
key
=
lambda
m
:
cast
(
float
,
m
.
metric
)
,
reverse
=
optimize_mode
==
'maximize'
)
assert
formatter
in
[
'code'
,
'dict'
],
'Export formatter other than "code" and "dict" is not supported yet.'
assert
formatter
in
[
'code'
,
'dict'
],
'Export formatter other than "code" and "dict" is not supported yet.'
if
formatter
==
'code'
:
if
formatter
==
'code'
:
return
[
model_to_pytorch_script
(
model
)
for
model
in
all_models
[:
top_k
]]
return
[
model_to_pytorch_script
(
model
)
for
model
in
all_models
[:
top_k
]]
...
...
nni/retiarii/graph.py
View file @
d38359e2
...
@@ -56,8 +56,8 @@ class Evaluator(abc.ABC):
...
@@ -56,8 +56,8 @@ class Evaluator(abc.ABC):
if
subclass
.
__name__
==
evaluator_type
:
if
subclass
.
__name__
==
evaluator_type
:
evaluator_type
=
subclass
evaluator_type
=
subclass
break
break
assert
issubclass
(
evaluator_type
,
Evaluator
)
assert
issubclass
(
cast
(
type
,
evaluator_type
)
,
Evaluator
)
return
evaluator_type
.
_load
(
ir
)
return
cast
(
Type
[
Evaluator
],
evaluator_type
)
.
_load
(
ir
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
...
@@ -350,7 +350,7 @@ class Graph:
...
@@ -350,7 +350,7 @@ class Graph:
if
isinstance
(
operation_or_type
,
Operation
):
if
isinstance
(
operation_or_type
,
Operation
):
op
=
operation_or_type
op
=
operation_or_type
else
:
else
:
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
op
=
Operation
.
new
(
operation_or_type
,
cast
(
dict
,
parameters
)
,
name
)
return
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
return
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
@
overload
@
overload
...
@@ -363,7 +363,7 @@ class Graph:
...
@@ -363,7 +363,7 @@ class Graph:
if
isinstance
(
operation_or_type
,
Operation
):
if
isinstance
(
operation_or_type
,
Operation
):
op
=
operation_or_type
op
=
operation_or_type
else
:
else
:
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
op
=
Operation
.
new
(
operation_or_type
,
cast
(
dict
,
parameters
)
,
name
)
new_node
=
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
new_node
=
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
# update edges
# update edges
self
.
add_edge
((
edge
.
head
,
edge
.
head_slot
),
(
new_node
,
None
))
self
.
add_edge
((
edge
.
head
,
edge
.
head_slot
),
(
new_node
,
None
))
...
...
nni/retiarii/integration_api.py
View file @
d38359e2
...
@@ -11,16 +11,18 @@ from nni.common.version import version_check
...
@@ -11,16 +11,18 @@ from nni.common.version import version_check
# because it would induce cycled import
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
_advisor
:
'
RetiariiAdvisor
'
=
None
_advisor
=
None
# type is
RetiariiAdvisor
def
get_advisor
()
->
'RetiariiAdvisor'
:
def
get_advisor
():
# return type: RetiariiAdvisor
global
_advisor
global
_advisor
assert
_advisor
is
not
None
assert
_advisor
is
not
None
return
_advisor
return
_advisor
def
register_advisor
(
advisor
:
'RetiariiAdvisor'
):
def
register_advisor
(
advisor
):
# type of advisor: RetiariiAdvisor
global
_advisor
global
_advisor
if
_advisor
is
not
None
:
if
_advisor
is
not
None
:
warnings
.
warn
(
'Advisor is already set.'
warnings
.
warn
(
'Advisor is already set.'
...
...
nni/retiarii/strategy/_rl_impl.py
View file @
d38359e2
...
@@ -141,7 +141,7 @@ class ModelEvaluationEnv(gym.Env[ObservationType, int]):
...
@@ -141,7 +141,7 @@ class ModelEvaluationEnv(gym.Env[ObservationType, int]):
wait_models
(
model
)
wait_models
(
model
)
if
model
.
status
==
ModelStatus
.
Failed
:
if
model
.
status
==
ModelStatus
.
Failed
:
return
self
.
reset
(),
0.
,
False
,
{}
return
self
.
reset
(),
0.
,
False
,
{}
rew
=
float
(
model
.
metric
)
rew
=
float
(
model
.
metric
)
# type: ignore
_logger
.
info
(
f
'Model metric received as reward:
{
rew
}
'
)
_logger
.
info
(
f
'Model metric received as reward:
{
rew
}
'
)
return
obs
,
rew
,
True
,
{}
return
obs
,
rew
,
True
,
{}
else
:
else
:
...
...
nni/tools/package_utils/tuner_factory.py
View file @
d38359e2
...
@@ -93,7 +93,7 @@ def create_validator_instance(algo_type, builtin_name):
...
@@ -93,7 +93,7 @@ def create_validator_instance(algo_type, builtin_name):
module_name
,
class_name
=
parse_full_class_name
(
meta
[
'classArgsValidator'
])
module_name
,
class_name
=
parse_full_class_name
(
meta
[
'classArgsValidator'
])
assert
module_name
is
not
None
assert
module_name
is
not
None
class_module
=
importlib
.
import_module
(
module_name
)
class_module
=
importlib
.
import_module
(
module_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
# type: ignore
return
class_constructor
()
return
class_constructor
()
...
@@ -149,7 +149,7 @@ def create_builtin_class_instance(
...
@@ -149,7 +149,7 @@ def create_builtin_class_instance(
raise
RuntimeError
(
'Builtin module can not be loaded: {}'
.
format
(
module_name
))
raise
RuntimeError
(
'Builtin module can not be loaded: {}'
.
format
(
module_name
))
class_module
=
importlib
.
import_module
(
module_name
)
class_module
=
importlib
.
import_module
(
module_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
class_constructor
=
getattr
(
class_module
,
class_name
)
# type: ignore
instance
=
class_constructor
(
**
class_args
)
instance
=
class_constructor
(
**
class_args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment