OpenDAS / nni — Commit 18962129 (unverified)

Add license header and typehints for NAS (#4774)

Authored Apr 25, 2022 by Yuge Zhang; committed by GitHub on Apr 25, 2022
Parent commit: 8c2f717d
Changes: 96 — this page shows 20 changed files with 249 additions and 166 deletions (+249 −166)
Changed files shown on this page:

nni/retiarii/nn/pytorch/cell.py  (+20 −11)
nni/retiarii/nn/pytorch/component.py  (+14 −12)
nni/retiarii/nn/pytorch/hypermodule.py  (+3 −1)
nni/retiarii/nn/pytorch/mutation_utils.py  (+5 −2)
nni/retiarii/nn/pytorch/mutator.py  (+35 −26)
nni/retiarii/nn/pytorch/nasbench101.py  (+18 −6)
nni/retiarii/nn/pytorch/nn.py  (+47 −22)
nni/retiarii/oneshot/pytorch/base_lightning.py  (+3 −2)
nni/retiarii/oneshot/pytorch/darts.py  (+2 −0)
nni/retiarii/oneshot/pytorch/proxyless.py  (+2 −0)
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py  (+3 −2)
nni/retiarii/oneshot/pytorch/supermodule/base.py  (+2 −2)
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py  (+2 −2)
nni/retiarii/oneshot/pytorch/supermodule/operation.py  (+22 −22)
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py  (+7 −5)
nni/retiarii/oneshot/pytorch/supermodule/sampling.py  (+4 −2)
nni/retiarii/oneshot/pytorch/utils.py  (+1 −1)
nni/retiarii/operation.py  (+19 −15)
nni/retiarii/operation_def/__init__.py  (+3 −0)
nni/retiarii/operation_def/torch_op_def.py  (+37 −33)
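Note on the recurring pattern in this commit: most of the typehint changes below either drop an implicit `= None` default from a typed parameter or narrow a possibly-None value with `typing.cast` before it is used. A minimal, self-contained sketch of the `cast` idiom follows; the `Graph`/`find_node` names are simplified stand-ins, not code from the diff:

from typing import Dict, Optional, cast

class Node:
    def __init__(self, name: str) -> None:
        self.name = name

class Graph:
    def __init__(self) -> None:
        self.nodes: Dict[str, Node] = {}

    def find_node(self, name: str) -> Optional[Node]:
        # May return None, so the static type of the result is Optional[Node].
        return self.nodes.get(name)

graph = Graph()
graph.nodes['fc'] = Node('fc')

# cast() has no runtime effect; it only tells the type checker that the lookup
# cannot be None here, which is how the diff uses cast(Node, model.get_node_by_name(...))
# and cast(Cell, node.operation).
target = cast(Node, graph.find_node('fc'))
print(target.name)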
nni/retiarii/nn/pytorch/cell.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import copy
 import warnings
-from typing import Callable, Dict, List, Union, Optional, Tuple
+from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast

 try:
     from typing import Literal
 except ImportError:
...
@@ -193,8 +196,10 @@ class Cell(nn.Module):
     def __init__(self,
                  op_candidates: Union[
                      Callable[[], List[nn.Module]],
-                     List[Union[nn.Module, _cell_op_factory_type]],
-                     Dict[str, Union[nn.Module, _cell_op_factory_type]]
+                     List[nn.Module],
+                     List[_cell_op_factory_type],
+                     Dict[str, nn.Module],
+                     Dict[str, _cell_op_factory_type]
                  ],
                  num_nodes: int,
                  num_ops_per_node: int = 1,
...
@@ -251,8 +256,8 @@ class Cell(nn.Module):
                 ops = self._convert_op_candidates(op_candidates, i, k, chosen)
                 # though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
-                self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
-                self.inputs[-1].append(inp)
+                cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
+                cast(ModuleList, self.inputs[-1]).append(inp)

     @property
     def label(self):
...
@@ -274,13 +279,17 @@ class Cell(nn.Module):
         By default, it's the output of ``merge_op``, which is a contenation (on ``concat_dim``)
         of some of (possibly all) the nodes' outputs in the cell.
         """
+        processed_inputs: List[torch.Tensor]
         if len(inputs) == 1 and isinstance(inputs[0], list):
-            inputs = inputs[0]
+            processed_inputs = list(inputs[0])  # shallow copy
         else:
-            inputs = list(inputs)
-        assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
-        states = self.preprocessor(inputs)
-        for ops, inps in zip(self.ops, self.inputs):
+            processed_inputs = cast(List[torch.Tensor], list(inputs))
+        assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
+        states: List[torch.Tensor] = self.preprocessor(processed_inputs)
+        for ops, inps in zip(cast(Sequence[Sequence[LayerChoice]], self.ops),
+                             cast(Sequence[Sequence[InputChoice]], self.inputs)):
             current_state = []
             for op, inp in zip(ops, inps):
                 current_state.append(op(inp(states)))
...
@@ -291,7 +300,7 @@ class Cell(nn.Module):
             this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
         else:
             this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
-        return self.postprocessor(this_cell, inputs)
+        return self.postprocessor(this_cell, processed_inputs)

     @staticmethod
     def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
...
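The widened `op_candidates` union above means a cell can now be declared with a plain list of modules, a list of factories, or a dict in either form. A rough usage sketch based only on the parameters visible in this hunk (remaining constructor arguments are left at their defaults; this is not code from the diff):

import torch.nn as nn
import nni.retiarii.nn.pytorch as retiarii_nn

# Dict[str, nn.Module]: the keys become the op names inside the cell.
cell = retiarii_nn.Cell(
    op_candidates={
        'conv3x3': nn.Conv2d(16, 16, 3, padding=1),
        'identity': nn.Identity(),
    },
    num_nodes=4,
    num_ops_per_node=1,
)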
nni/retiarii/nn/pytorch/component.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import copy
 import warnings
 from collections import OrderedDict
-from typing import Callable, List, Union, Tuple, Optional
+from typing import Callable, List, Dict, Union, Tuple, Optional

 import torch
 import torch.nn as nn

 from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
-from .api import LayerChoice, ValueChoice, ValueChoiceX
+from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
 from .mutation_utils import Mutable, generate_new_label, get_fixed_value
...
@@ -64,7 +67,7 @@ class Repeat(Mutable):
                           List[Callable[[int], nn.Module]], nn.Module, List[nn.Module]],
-                depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None):
+                depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
         if isinstance(depth, tuple):
             # we can't create a value choice here,
             # otherwise we will have two value choices, one created here, another in init.
...
@@ -90,7 +93,7 @@ class Repeat(Mutable):
                           List[Callable[[int], nn.Module]], nn.Module, List[nn.Module]],
-                 depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
+                 depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
         super().__init__()
         self._label = None  # by default, no label
...
@@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module):
             return OrderedDict([(str(i), t) for i, t in enumerate(x)])
         return OrderedDict(x)

-    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
+    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                  in_features: int, out_features: int, num_tensors: int = 4,
                  label: Optional[str] = None):
         super().__init__()
...
@@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module):
                 node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
             self.layers.append(node_ops)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         The forward of input choice is simply selecting first on all choices.
         It shouldn't be called directly by users in most cases.
         """
-        tensors = [inputs]
+        tensors: List[torch.Tensor] = [inputs]
         for layer in self.layers:
-            current_tensor = []
-            for i, op in enumerate(layer):
-                current_tensor.append(op(tensors[i]))
-            current_tensor = torch.sum(torch.stack(current_tensor), 0)
-            tensors.append(current_tensor)
+            current_tensor: List[torch.Tensor] = []
+            for i, op in enumerate(layer):  # type: ignore
+                current_tensor.append(op(tensors[i]))  # type: ignore
+            tensors.append(torch.sum(torch.stack(current_tensor), 0))
         return tensors[-1]
nni/retiarii/nn/pytorch/hypermodule.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 from packaging.version import Version

 import torch
 import torch.nn as nn
...
@@ -233,7 +235,7 @@ class AutoActivation(nn.Module):
     -----
     Current `beta` is not per-channel parameter.
     """
-    def __init__(self, unit_num: int = 1, label: str = None):
+    def __init__(self, unit_num: int = 1, label: str | None = None):
         super().__init__()
         self._label = generate_new_label(label)
         self.unaries = nn.ModuleList()
...
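Why the `from __future__ import annotations` line at the top of this file matters for the `str | None` annotation above — a small sketch (the `make_label` function is made up for illustration):

from __future__ import annotations  # annotations are stored as strings, not evaluated

def make_label(label: str | None = None) -> str:
    # Without the __future__ import, evaluating `str | None` in an annotation raises
    # TypeError on Python 3.7-3.9; with it, the PEP 604 union syntax is accepted there
    # because the annotation is never evaluated at runtime.
    return label if label is not None else 'default'

print(make_label())        # 'default'
print(make_label('relu'))  # 'relu'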
nni/retiarii/nn/pytorch/mutation_utils.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 from typing import Any, Optional, Tuple, Union

 import torch.nn as nn
...
@@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]):
     return label


-def get_fixed_value(label: str) -> Any:
+def get_fixed_value(label: Optional[str]) -> Any:
     ret = get_current_context('fixed')
     try:
         return ret[generate_new_label(label)]
...
@@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any:
         raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')


-def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
+def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
     ret = get_current_context('fixed')
     try:
         label_prefix = generate_new_label(label_prefix)
...
nni/retiarii/nn/pytorch/mutator.py

...
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import inspect
-from typing import Any, List, Optional, Tuple, Dict, Iterator
+from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast

 import torch.nn as nn
...
@@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator):
             # Each layer choice corresponds to a cell, which is unconnected in the base graph.
             # We add the connections here in the mutation logic.
             # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
-            target = model.graphs[node.operation.cell_name]
+            target = model.graphs[cast(Cell, node.operation).cell_name]
             chosen_node = target.get_node_by_name(chosen)
             assert chosen_node is not None
             target.add_edge((target.input_node, 0), (chosen_node, None))
             target.add_edge((chosen_node, None), (target.output_node, None))
-            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            operation = cast(Cell, node.operation)
+            target_node = cast(Node, model.get_node_by_name(node.name))
+            target_node.update_operation(Cell(operation.cell_name))

             # remove redundant nodes
             for rm_node in list(target.hidden_nodes):  # remove from a list on the fly will cause issues
...
@@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator):
         else:
             chosen = [self.choice(candidates) for _ in range(n_chosen)]
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
                                     {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
...
@@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator):
         # no need to support transformation here,
         # because it is naturally done in forward loop
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
...
@@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator):
         super().__init__(label=label)
         self.candidates = candidates

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # leave a record here
         # real mutations will be done in ParameterChoiceMutator
         self.choice(self.candidates)
...
@@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator):
         self.nodes = nodes

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # looks like {"label1": "cat", "label2": 123}
         value_choice_decisions = {}
         for mutation in model.history:
...
@@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator):
             result_value = value_choice.evaluate(leaf_node_values)

             # update model with graph mutation primitives
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
...
@@ -138,20 +140,20 @@ class RepeatMutator(Mutator):
         while u != graph.output_node:
             if u != graph.input_node:
                 chain.append(u)
-            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.'
+            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
             u = u.successors[0]
         return chain

     def mutate(self, model):
         for node in self.nodes:
             # the logic here is similar to layer choice. We find cell attached to each node.
-            target: Graph = model.graphs[node.operation.cell_name]
+            target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
             chain = self._retrieve_chain_from_graph(target)
             # and we get the chosen depth (by value choice)
-            node_in_model = model.get_node_by_name(node.name)
+            node_in_model = cast(Node, model.get_node_by_name(node.name))
             # depth is a value choice in base model
             # but it's already mutated by a ParameterChoiceMutator here
-            chosen_depth = node_in_model.operation.parameters['depth']
+            chosen_depth: int = node_in_model.operation.parameters['depth']
             for edge in chain[chosen_depth - 1].outgoing_edges:
                 edge.remove()
             target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
...
@@ -159,8 +161,11 @@ class RepeatMutator(Mutator):
                 for edge in rm_node.outgoing_edges:
                     edge.remove()
                 rm_node.remove()
             # to delete the unused parameters.
-            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            target_node = cast(Node, model.get_node_by_name(node.name))
+            cell_operation = cast(Cell, node.operation)
+            target_node.update_operation(Cell(cell_operation.cell_name))


 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...
@@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator):
     Choose based on labels. Will not affect the model itself.
     """

-    def __init__(self, label: Optional[str]):
+    def __init__(self, label: str):
         super().__init__(label=label)

     @staticmethod
...
@@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator):
             return node.operation.parameters['n_chosen']
         return 1

-    def mutate(self, model: Model):
+    def mutate(self, model: Model) -> None:
         # this mutate does not have any effect, but it is recorded in the mutation history
         for node in model.get_nodes_by_label(self.label):
             n_chosen = self.number_of_chosen(node)
...
@@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
         if not is_model_wrapped(pytorch_model):
             raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
                              'if your model has init parameters.')
-        model.python_init_params = pytorch_model.trace_kwargs
+        model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
     else:
         model.python_init_params = {}

     # hyper-parameter choice
-    namespace: ModelNamespace = pytorch_model._model_namespace
+    namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
     for param_spec in namespace.parameter_specs:
         assert param_spec.categorical and param_spec.type == 'choice'
         node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
...
@@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
     for name, module in pytorch_model.named_modules():
         # tricky case: value choice that serves as parameters are stored in traced arguments
         if is_basic_unit(module):
-            for key, value in module.trace_kwargs.items():
+            trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
+            for key, value in trace_kwargs.items():
                 if isinstance(value, ValueChoiceX):
                     for i, choice in enumerate(value.inner_choices()):
                         node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
...
@@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
     mutators = []
     mutators_final = []
     for nodes in _group_by_label_and_type(graph.hidden_nodes):
+        label = nodes[0].label
+        assert label is not None, f'label of {nodes[0]} can not be None.'
         assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
-            f'Node with label "{nodes[0].label}" does not all have the same type.'
+            f'Node with label "{label}" does not all have the same type.'
         assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
-            f'Node with label "{nodes[0].label}" does not agree on parameters.'
+            f'Node with label "{label}" does not agree on parameters.'
         if nodes[0].operation.type == 'NasBench101Cell':
-            mutators_final.append(NasBench101Mutator(nodes[0].label))
+            # The mutation of Nas-bench-101 is special, and has to be done lastly.
+            mutators_final.append(NasBench101Mutator(label))
         else:
-            mutators.append(ManyChooseManyMutator(nodes[0].label))
+            mutators.append(ManyChooseManyMutator(label))
     return model, mutators + mutators_final
...
@@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator):
         super().__init__(label=label)
         self.candidates = candidates

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # leave a record here
         # real mutations will be done in ParameterChoiceMutator
         self.choice(self.candidates)
...
@@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator):
         return obj

-    def mutate(self, model: Model):
+    def mutate(self, model: Model) -> None:
         value_choice_decisions = {}
         for mutation in model.history:
             if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
...
@@ -454,7 +463,7 @@ def _is_all_equal(lst):
     return True


-def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
+def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
         key = (node.label, node.operation.type)
...
@@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
     return list(result.values())


-def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
+def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
         label = node.operation.parameters['label']
...
nni/retiarii/nn/pytorch/nasbench101.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import logging
 from collections import OrderedDict
-from typing import Callable, List, Optional, Union, Dict
+from typing import Callable, List, Optional, Union, Dict, Tuple, cast

 import numpy as np
 import torch
...
@@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix):
     return vertex_channels


-def prune(matrix, ops):
+def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
     """
     Prune the extraneous parts of the graph.
...
@@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module):
         assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1

-        self.operations = ['IN'] + operations + ['OUT']  # add psuedo nodes
+        raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
+        del operations  # operations is no longer needed. Delete it to avoid misuse
+
+        # add psuedo nodes
+        raw_operations.insert(0, 'IN')
+        raw_operations.append('OUT')
+
         self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
         del num_nodes  # raw number of nodes is no longer used

-        self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations)
+        self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)

         self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
...
@@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module):
             self.projections.append(projection(in_features, self.hidden_features[i]))

         for i in range(1, self.num_nodes - 1):
-            self.ops.append(operations[i - 1](self.hidden_features[i]))
+            operation = cast(Callable[[int], nn.Module], self.operations[i])
+            self.ops.append(operation(self.hidden_features[i]))

     @staticmethod
     def build_connection_matrix(adjacency_list, num_nodes):
...
@@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator):
     # for validation purposes
     # for python execution engine

-    def __init__(self, label: Optional[str]):
+    def __init__(self, label: str):
         super().__init__(label=label)

     @staticmethod
...
@@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator):
         return 1

     def mutate(self, model: Model):
+        max_num_edges = cast(int, None)
         for node in model.get_nodes_by_label(self.label):
             max_num_edges = node.operation.parameters['max_num_edges']
             break
+        assert max_num_edges is not None
         mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
         num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
         adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
...
nni/retiarii/nn/pytorch/nn.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import inspect
-import warnings
 from pathlib import Path

 import torch
 import torch.nn as nn

 # To make auto-completion happy, we generate a _nn.py that lists out all the classes.
 nn_cache_file_path = Path(__file__).parent / '_nn.py'
-cache_valid = False

+# Update this when cache format changes, to enforce an update.
+cache_version = 2
+
+
+def validate_cache() -> bool:
+    import torch
+
+    cache_valid = []
+
+    if nn_cache_file_path.exists():
+        lines = nn_cache_file_path.read_text().splitlines()
+        for line in lines:
+            if line.startswith('# _torch_version'):
+                _cached_torch_version = line[line.find('=') + 1:].strip()
+                if _cached_torch_version == torch.__version__:
+                    cache_valid.append(True)
+            if line.startswith('# _torch_nn_cache_version'):
+                _cached_cache_version = int(line[line.find('=') + 1:].strip())
+                if _cached_cache_version == cache_version:
+                    cache_valid.append(True)
+
+    return len(cache_valid) >= 2 and all(cache_valid)

-if nn_cache_file_path.exists():
-    from . import _nn  # pylint: disable=no-name-in-module
-    # valid only when torch version match
-    if _nn._torch_version == torch.__version__:
-        cache_valid = True

-if not cache_valid:
+def generate_stub_file() -> str:
+    import inspect
+    import warnings
+
+    import torch
+    import torch.nn as nn

     _NO_WRAP_CLASSES = [
         # not an nn.Module
         'Parameter',
...
@@ -47,7 +63,10 @@ if not cache_valid:
         '# This file is auto-generated to make auto-completion work.',
         '# When pytorch version does not match, it will get automatically updated.',
         '# pylint: skip-file',
-        f'_torch_version = "{torch.__version__}"',
+        '# pyright: reportGeneralTypeIssues=false',
+        f'# _torch_version = {torch.__version__}',
+        f'# _torch_nn_cache_version = {cache_version}',
+        'import typing',
         'import torch.nn as nn',
         'from nni.retiarii.serializer import basic_unit',
     ]
...
@@ -66,10 +85,9 @@ if not cache_valid:
                           'It means your PyTorch version might not be supported.', RuntimeWarning)
             code.append(f'{name} = nn.{name}')
         elif name in _WRAP_WITHOUT_TAG_CLASSES:
-            code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
+            code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
         else:
-            code.append(f'{name} = basic_unit(nn.{name})')
+            code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')

         all_names.append(name)

     elif inspect.isfunction(obj) or inspect.ismodule(obj):
...
@@ -78,12 +96,19 @@ if not cache_valid:
     code.append(f'__all__ = {all_names}')

+    return '\n'.join(code)
+
+
+def write_cache(code: str) -> None:
     with nn_cache_file_path.open('w') as fp:
-        fp.write('\n'.join(code))
+        fp.write(code)
+
+
+code = generate_stub_file()
+if not validate_cache():
+    write_cache(code)

 # Import all modules from generated _nn.py
+del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
 from . import _nn  # pylint: disable=no-name-in-module
 __all__ = _nn.__all__
-from ._nn import *  # pylint: disable=import-error, wildcard-import
+from ._nn import *  # pylint: disable=import-error, wildcard-import, unused-wildcard-import
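Given the `generate_stub_file` logic above, the generated `_nn.py` cache should look roughly like the sketch below. The torch version shown is a placeholder and only two representative entries are included; the real file lists every wrapped `torch.nn` class of the installed PyTorch:

# This file is auto-generated to make auto-completion work.
# When pytorch version does not match, it will get automatically updated.
# pylint: skip-file
# pyright: reportGeneralTypeIssues=false
# _torch_version = 1.11.0
# _torch_nn_cache_version = 2
import typing
import torch.nn as nn
from nni.retiarii.serializer import basic_unit

Parameter = nn.Parameter  # listed in _NO_WRAP_CLASSES, re-exported as-is
Linear = typing.cast(typing.Type[nn.Linear], basic_unit(nn.Linear))
__all__ = ['Parameter', 'Linear']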
nni/retiarii/oneshot/pytorch/base_lightning.py

...
@@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule
 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']

-MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
+MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]


 def traverse_and_mutate_submodules(
...
@@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
         The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.

-        To be more specific, the input arguments are three arguments:
+        To be more specific, the input arguments are four arguments:

         #. a module that might be processed,
         #. name of the module in its parent module,
         #. a memo dict whose usage depends on the particular algorithm.
+        #. keyword arguments (configurations).

         Note that the memo should be read/written by hooks.
         There won't be any hooks called on root module.
...
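The updated docstring above now lists four hook arguments. A minimal sketch of a callable matching the new `MutationHook` alias (the hook name, its body, and the memo key are invented for illustration; what the return value means is defined by each one-shot algorithm):

from typing import Any, Dict, Tuple, Union
import torch.nn as nn

def count_linear_hook(
    module: nn.Module,              # 1. a module that might be processed
    name: str,                      # 2. name of the module in its parent module
    memo: Dict[str, Any],           # 3. memo dict, read/written by hooks
    mutate_kwargs: Dict[str, Any],  # 4. keyword arguments (configurations)
) -> Union[nn.Module, bool, Tuple[nn.Module, bool]]:
    if isinstance(module, nn.Linear):
        memo['num_linear'] = memo.get('num_linear', 0) + 1
    # Per the MutationHook alias, a hook may return a replacement module,
    # a bool, or a (module, bool) pair.
    return False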
nni/retiarii/oneshot/pytorch/darts.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore

 import copy
 import logging
 from collections import OrderedDict
...
nni/retiarii/oneshot/pytorch/proxyless.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore

 import logging

 import torch
...
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py

...
@@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor.
 """

 import operator
-from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic
+from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast

 import numpy as np
 import torch
...
@@ -128,9 +128,10 @@ class Slicable(Generic[T]):
             raise TypeError(f'Unsuppoted weight type: {type(weight)}')
         self.weight = weight

-    def __getitem__(self, index: multidim_slice) -> T:
+    def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
         if not isinstance(index, tuple):
             index = (index, )
+        index = cast(multidim_slice, index)

         # Get the dict value in index's leafs
         # There can be at most one dict
...
nni/retiarii/oneshot/pytorch/supermodule/base.py

...
@@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module):
     rather than their compositions.
     """

-    def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
         """
         Resample the super-net module.
...
@@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module):
         """
         raise NotImplementedError()

-    def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
         """
         Export the final architecture within this module.
         It should have the same keys as ``search_space_spec()``.
...
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py

...
@@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             if not arch:
                 yield name, p

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Differentiable. Do nothing in resample."""
         return {}

-    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         for name, spec in operation.search_space_spec().items():
...
nni/retiarii/oneshot/pytorch/supermodule/operation.py

...
@@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entang
 import inspect
 import itertools
-from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar
+from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor

 import nni.retiarii.nn.pytorch as retiarii_nn
 from nni.common.hpo_utils import ParameterSpec
...
@@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy:
         """
         pass

-    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
         """The handler of :meth:`MixedOperation.resample`."""
         raise NotImplementedError()

-    def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
         """The handler of :meth:`MixedOperation.export`."""
         raise NotImplementedError()
...
@@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         embed_dim = _W(embed_dim)

         # in projection weights & biases has q, k, v weights concatenated together
-        in_proj_bias = in_proj_weight = None
+        in_proj_bias: Optional[Tensor] = None
+        in_proj_weight: Optional[Tensor] = None
         if self.in_proj_bias is not None:
-            in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)]
+            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
         if self.in_proj_weight is not None:
-            in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim]
+            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]

-        bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None
-        bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None
-        out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim]
-        out_proj_bias = _S(self.out_proj.bias)[:embed_dim] if self.out_proj.bias is not None else None
+        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
+        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
+        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
+        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None

         if not qkv_same_embed_dim:
             kdim = _W(kdim)
             vdim = _W(vdim)

-            q_proj = _S(self.q_proj_weight)[:embed_dim, :embed_dim]
-            k_proj = _S(self.k_proj_weight)[:embed_dim]
-            k_proj = _S(k_proj)[:, :kdim]
-            v_proj = _S(self.v_proj_weight)[:embed_dim]
-            v_proj = _S(v_proj)[:, :vdim]
+            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
+            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
+            k_proj = _S(k_proj)[:, :_W(kdim)]
+            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
+            v_proj = _S(v_proj)[:, :_W(vdim)]

             # The rest part is basically same as pytorch
             attn_output, attn_output_weights = F.multi_head_attention_forward(
                 query, key, value, used_embed_dim, num_heads,
-                in_proj_weight, in_proj_bias,
+                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                 bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, out_proj_bias,
+                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                 training=self.training,
                 key_padding_mask=key_padding_mask, need_weights=need_weights,
                 attn_mask=attn_mask, use_separate_proj_weight=True,
                 q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
         else:
+            # Cast tensor here because of a bug in pytorch stub
             attn_output, attn_output_weights = F.multi_head_attention_forward(
                 query, key, value, used_embed_dim, num_heads,
-                in_proj_weight, in_proj_bias,
+                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                 bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, out_proj_bias,
+                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                 training=self.training,
                 key_padding_mask=key_padding_mask, need_weights=need_weights,
                 attn_mask=attn_mask)
...
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py

...
@@ -9,7 +9,7 @@ The support remains limited. Known limitations include:
 - The code contains duplicates. Needs refactor.
 """

-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, cast

 import torch
 import torch.nn as nn
...
@@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
             self._sample_idx = self.op_names.index(self._sampled)
         else:
             probs = self._softmax(self._arch_alpha)
-            self._sample_idx = torch.multinomial(probs, 1)[0].item()
+            self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
             self._sampled = self.op_names[self._sample_idx]

         # set binary gates
...
@@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
         """Chose the argmax if label isn't found in memo."""
         if self.label in memo:
             return {}  # nothing new to export
-        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
+        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

     def finalize_grad(self):
         binary_grads = self._binary_gates.grad
+        assert binary_grads is not None
         with torch.no_grad():
             if self._arch_alpha.grad is None:
                 self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
...
@@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput):
         else:
             probs = self._softmax(self._arch_alpha)
             sample = torch.multinomial(probs, 1)[0].item()
-            self._sampled = sample
+            self._sampled = int(sample)

         # set binary gates
         with torch.no_grad():
             self._binary_gates.zero_()
             self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[sample] = 1.0
+            self._binary_gates.data[cast(int, self._sampled)] = 1.0

         return {self.label: self._sampled}
...
@@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput):
     def finalize_grad(self):
         binary_grads = self._binary_gates.grad
+        assert binary_grads is not None
         with torch.no_grad():
             if self._arch_alpha.grad is None:
                 self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
...
nni/retiarii/oneshot/pytorch/supermodule/sampling.py

...
@@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule):
         if isinstance(module, InputChoice):
             if module.reduction not in ['sum', 'mean', 'concat']:
                 raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
+            if module.n_chosen is None:
+                raise ValueError('n_chosen is None is not supported yet.')
             return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)

     def forward(self, input_tensors):
...
@@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
         self._sampled: Optional[Dict[str, Any]] = None

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Random sample for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
...
@@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         return result

-    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
...
nni/retiarii/oneshot/pytorch/utils.py

...
@@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules):
         for name, child in m.named_children():
             if isinstance(child, type_name):
                 setattr(m, name, init_fn(child))
-                modules.append((child.key, getattr(m, name)))
+                modules.append((child.label, getattr(m, name)))
             else:
                 apply(child)
...
nni/retiarii/operation.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import (Any, Dict, List)
+from typing import (Any, Dict, List, Optional, cast)

 from . import debug_configs
...
@@ -34,6 +34,8 @@ class Operation:
         Arbitrary key-value parameters (e.g. kernel_size).
     """

+    io_names: List[str] = []
+
     def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
         assert _internal, '`Operation()` is private, use `Operation.new()` instead'
         self.type: str = type_name
...
@@ -43,7 +45,7 @@ class Operation:
     def to_init_code(self, field: str) -> str:
         raise NotImplementedError()

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         raise NotImplementedError()

     def _to_class_name(self) -> str:
...
@@ -53,8 +55,8 @@ class Operation:
         return True

     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
-            attributes: Dict[str, Any] = None) -> 'Operation':
+    def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
+            attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
         parameters = parameters or {}
         attributes = attributes or {}
         if type_name == '_cell':
...
@@ -98,16 +100,16 @@ class PyTorchOperation(Operation):
             subclass_name = 'FunctionalOperator'
         for subclass in cls.__subclasses__():
             if hasattr(subclass, '_ori_type_name') and \
-                subclass_name in subclass._ori_type_name:
+                subclass_name in cast(Any, subclass)._ori_type_name:
                 return subclass
         for subclass in cls.__subclasses__():
             if hasattr(subclass, '_artificial_op_name') and \
-                subclass_name in subclass._artificial_op_name:
+                subclass_name in cast(Any, subclass)._artificial_op_name:
                 return subclass
         return cls

     @classmethod
-    def to_class_name(cls, type_name) -> str:
+    def to_class_name(cls, type_name) -> Optional[str]:
         if type_name.startswith('__torch__.'):
             return type_name[len('__torch__.'):]
         elif type_name.startswith('__mutated__.'):
...
@@ -119,7 +121,7 @@ class PyTorchOperation(Operation):
     def is_functional(cls, type_name) -> bool:
         return type_name.startswith('Function.')

-    def _to_class_name(self) -> str:
+    def _to_class_name(self) -> Optional[str]:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):]
         elif self.type.startswith('__mutated__.'):
...
@@ -127,7 +129,7 @@ class PyTorchOperation(Operation):
         else:
             return None

-    def get_import_pkg(self) -> str:
+    def get_import_pkg(self) -> Optional[str]:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):].split('.')[0]
         elif self.type.startswith('__mutated__.'):
...
@@ -135,14 +137,14 @@ class PyTorchOperation(Operation):
         else:
             return None

-    def to_init_code(self, field: str) -> str:
+    def to_init_code(self, field: str) -> Optional[str]:
         if self._to_class_name() is not None:
             assert 'positional_args' not in self.parameters
             kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
             return f'self.{field} = {self._to_class_name()}({kw_params})'
         return None

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         """
         Parameters
         ----------
...
@@ -207,7 +209,9 @@ class Cell(PyTorchOperation):
     No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
+    def __init__(self, cell_name: str,
+                 parameters: Dict[str, Any] = cast(Dict[str, Any], None),
+                 attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
         self.type = '_cell'
         self.cell_name = cell_name
         self.parameters = parameters or {}
...
@@ -217,7 +221,7 @@ class Cell(PyTorchOperation):
         # TODO: ugly, think about how to refactor this part
         return _convert_name(self.cell_name)

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = self.{field}({", ".join(inputs)})'


 class _IOPseudoOperation(Operation):
...
@@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation):
     especially in static type checking.
     """

-    def __init__(self, type_name: str, io_names: List = None):
+    def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
         assert type_name.startswith('_')
         super(_IOPseudoOperation, self).__init__(type_name, {}, True)
         self.io_names = io_names
...
@@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation):
     def to_init_code(self, field: str) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

     def __bool__(self) -> bool:
...
nni/retiarii/operation_def/__init__.py

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 """
 Definition of operation types.
...
nni/retiarii/operation_def/torch_op_def.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 from typing import (Any, Dict, List)

 import torch
+import torch.nn.functional as nn_functional

 from ..operation import PyTorchOperation
...
@@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation):
     """
     _ori_type_name = ['noop_identity']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {", ".join(inputs)}'

 class ModuleOperator(PyTorchOperation):
     _ori_type_name = ['ModuleOperator', 'shared']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = self.{field}({", ".join(inputs)})'

 class FunctionalOperator(PyTorchOperation):
     _ori_type_name = ['FunctionalOperator']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         func_name = self.type[len('Function.'):]
-        if not hasattr(torch.nn.functional, func_name):
+        if not hasattr(nn_functional, func_name):
             raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
                                f'{func_name} is not in it.')
         return f'{output} = F.{func_name}({", ".join(inputs)})'
...
@@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation):
 class PrimConstant(PyTorchOperation):
     _ori_type_name = ['prim::Constant']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
         # TODO: deal with all the types
         if self.parameters['type'] in ['None', 'NoneType']:
...
@@ -87,28 +90,28 @@ class PrimConstant(PyTorchOperation):
 class PrimListConstruct(PyTorchOperation):
     _ori_type_name = ['prim::ListConstruct']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = [{", ".join(inputs)}]'

 class PrimListUnpack(PyTorchOperation):
     _ori_type_name = ['prim::ListUnpack']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]}'

 class PrimTupleConstruct(PyTorchOperation):
     _ori_type_name = ['prim::TupleConstruct']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = ({", ".join(inputs)})'

 class PrimTupleUnpack(PyTorchOperation):
     _ori_type_name = ['prim::TupleUnpack']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # have single output here, because the following code uses index to access the unpacked values
         assert len(inputs) == 1
         return f'{output} = {inputs[0]}'
...
@@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation):
 class PrimGetAttr(PyTorchOperation):
     _ori_type_name = ['prim::GetAttr']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if self.parameters['value'] is not None:
             return f"{output} = {self.parameters['value']}"
         else:
...
@@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation):
 class PrimUncheckedCast(PyTorchOperation):
     _ori_type_name = ['prim::unchecked_cast']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]}'

 class SimpleMember(PyTorchOperation):
     _ori_type_name = ['prim::is_cuda', 'prim::data']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         member_name = self.type.split('::')[-1]
         return f'{output} = {inputs[0]}.{member_name}'
...
@@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation):
 class AtenContiguous(PyTorchOperation):
     _ori_type_name = ['aten::contiguous']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # defined in pytorch/c10/core/MemoryFormat.h
-        assert inputs_value[1] in [0, 1, 2]
+        assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
         return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'

 class AtenGetitem(PyTorchOperation):
     _ori_type_name = ['aten::__getitem__']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'{output} = {inputs[0]}[{inputs[1]}]'
...
@@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation):
 class AtenAppend(PyTorchOperation):
     _ori_type_name = ['aten::append']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
...
@@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation):
 class MergedSlice(PyTorchOperation):
     _ori_type_name = ['MergedSlice']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if (len(inputs) - 1) % 4 == 0:
             slices = []
             dim = int((len(inputs) - 1) / 4)
...
@@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation):
 class AtenBool(PyTorchOperation):
     _ori_type_name = ['aten::Bool']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = bool({inputs[0]})'

 class AtenNot(PyTorchOperation):
     _ori_type_name = ['aten::__not__']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = not {inputs[0]}'

 class AtenCat(PyTorchOperation):
     _ori_type_name = ['aten::cat']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
...
@@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation):
                       'aten::new_empty', 'aten::new_zeros', 'aten::arange',
                       'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         schemas = torch._C._jit_get_schemas_for_operator(self.type)
         # match number of inputs
         overloaded_defs = [len(s.arguments) for s in schemas]
...
@@ -257,40 +260,41 @@ class AtenTensors(PyTorchOperation):
 class AtenFloordiv(PyTorchOperation):
     _ori_type_name = ['aten::floordiv']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]} // {inputs[1]}'

 class AtenMul(PyTorchOperation):
     _ori_type_name = ['aten::mul']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]} * {inputs[1]}'

 class AtenLen(PyTorchOperation):
     _ori_type_name = ['aten::len']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = len({inputs[0]})'

 class AtenIntImplicit(PyTorchOperation):
     _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if self.type.endswith('Implicit'):
             return f'{output} = {inputs[0]}'
         elif self.type == 'aten::Int':
             return f'{output} = int({inputs[0]})'
         elif self.type == 'aten::Float':
             return f'{output} = float({inputs[0]})'
+        raise TypeError(f'Unexpected type: {self.type}')

 class AtenIndex(PyTorchOperation):
     _ori_type_name = ['aten::index']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]}[{inputs[1]}]'
...
@@ -355,13 +359,13 @@ def _get_tensor_ops():
 def _get_torch_ops():
     torch_op_args = {}
-    for mod in torch.jit._builtins._modules_containing_builtins:
+    for mod in torch.jit._builtins._modules_containing_builtins:  # type: ignore
         name = mod.__name__
         if name == 'torch._C._nn':
             continue
         # only process 'torch.XXX'
         for elem in dir(mod):
-            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
+            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))  # type: ignore
             if builtin is not None:
                 schemas = torch._C._jit_get_schemas_for_operator(builtin)
                 for schema in schemas:
...
@@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation):
             return None
         raise RuntimeError(f'tensor op type {_type} has no matched')

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # TODO: deal with conditional ops
         if self.type in TensorOps.comparison_ops:
             return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
...
@@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation):
         else:
             raise RuntimeError(f'torch op type {_type} has no matched')

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         matched_args = TorchOps._get_matched_args(self.type, inputs)
         op_name = self.type.split('::')[-1]
         args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
...
@@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation):
     # NOTE: it is not included in the above aten ops for unkown reason
     _ori_type_name = ['aten::avg_pool2d']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = F.avg_pool2d({", ".join(inputs)})'
...
@@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation):
     _artificial_op_name = "ToDevice"

-    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False, attributes: Dict[str, Any] = None):
+    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False, attributes: Dict[str, Any] = {}):
         self.type = "ToDevice"
         self.device = parameters['device']
         self.overridden_device_repr = None
...
@@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation):
     # NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
     _ori_type_name = ['aten::linalg_det']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = torch.det({inputs[0]})'