OpenDAS / nni · Commits · Commit 18962129 (Unverified)

Add license header and typehints for NAS (#4774)

Authored Apr 25, 2022 by Yuge Zhang; committed by GitHub on Apr 25, 2022.
Parent: 8c2f717d

Showing 20 changed files with 249 additions and 166 deletions (+249 −166).
nni/retiarii/nn/pytorch/cell.py                                  +20 −11
nni/retiarii/nn/pytorch/component.py                             +14 −12
nni/retiarii/nn/pytorch/hypermodule.py                            +3 −1
nni/retiarii/nn/pytorch/mutation_utils.py                         +5 −2
nni/retiarii/nn/pytorch/mutator.py                               +35 −26
nni/retiarii/nn/pytorch/nasbench101.py                           +18 −6
nni/retiarii/nn/pytorch/nn.py                                    +47 −22
nni/retiarii/oneshot/pytorch/base_lightning.py                    +3 −2
nni/retiarii/oneshot/pytorch/darts.py                             +2 −0
nni/retiarii/oneshot/pytorch/proxyless.py                         +2 −0
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py      +3 −2
nni/retiarii/oneshot/pytorch/supermodule/base.py                  +2 −2
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py        +2 −2
nni/retiarii/oneshot/pytorch/supermodule/operation.py            +22 −22
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py             +7 −5
nni/retiarii/oneshot/pytorch/supermodule/sampling.py              +4 −2
nni/retiarii/oneshot/pytorch/utils.py                             +1 −1
nni/retiarii/operation.py                                        +19 −15
nni/retiarii/operation_def/__init__.py                            +3 −0
nni/retiarii/operation_def/torch_op_def.py                       +37 −33
nni/retiarii/nn/pytorch/cell.py (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import copy
 import warnings
-from typing import Callable, Dict, List, Union, Optional, Tuple
+from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast
 try:
     from typing import Literal
 except ImportError:
...
@@ -193,8 +196,10 @@ class Cell(nn.Module):
     def __init__(self,
                  op_candidates: Union[
                      Callable[[], List[nn.Module]],
-                     List[Union[nn.Module, _cell_op_factory_type]],
-                     Dict[str, Union[nn.Module, _cell_op_factory_type]]
+                     List[nn.Module],
+                     List[_cell_op_factory_type],
+                     Dict[str, nn.Module],
+                     Dict[str, _cell_op_factory_type]
                  ],
                  num_nodes: int,
                  num_ops_per_node: int = 1,
...
@@ -251,8 +256,8 @@ class Cell(nn.Module):
                 ops = self._convert_op_candidates(op_candidates, i, k, chosen)
                 # though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
-                self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
-                self.inputs[-1].append(inp)
+                cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
+                cast(ModuleList, self.inputs[-1]).append(inp)

     @property
     def label(self):
...
@@ -274,13 +279,17 @@ class Cell(nn.Module):
         By default, it's the output of ``merge_op``, which is a contenation (on ``concat_dim``)
         of some of (possibly all) the nodes' outputs in the cell.
         """
+        processed_inputs: List[torch.Tensor]
         if len(inputs) == 1 and isinstance(inputs[0], list):
-            inputs = inputs[0]
+            processed_inputs = list(inputs[0])  # shallow copy
         else:
-            inputs = list(inputs)
-        assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
-        states = self.preprocessor(inputs)
-        for ops, inps in zip(self.ops, self.inputs):
+            processed_inputs = cast(List[torch.Tensor], list(inputs))
+        assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
+        states: List[torch.Tensor] = self.preprocessor(processed_inputs)
+        for ops, inps in zip(cast(Sequence[Sequence[LayerChoice]], self.ops),
+                             cast(Sequence[Sequence[InputChoice]], self.inputs)):
             current_state = []
             for op, inp in zip(ops, inps):
                 current_state.append(op(inp(states)))
...
@@ -291,7 +300,7 @@ class Cell(nn.Module):
             this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
         else:
             this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
-        return self.postprocessor(this_cell, inputs)
+        return self.postprocessor(this_cell, processed_inputs)

     @staticmethod
     def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
...
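The cell.py hunks lean on `typing.cast(ModuleList, ...)`, a pattern worth spelling out: `cast` only changes what the static checker believes, never the runtime value. Below is a minimal, self-contained sketch of that pattern with illustrative names; it is not code from this commit.

from typing import cast

import torch.nn as nn

container: nn.Module = nn.ModuleList()  # statically known only as nn.Module

# cast() is a runtime no-op: it returns `container` unchanged, but tells the
# type checker to treat it as a ModuleList so that .append() type-checks.
cast(nn.ModuleList, container).append(nn.Linear(4, 4))

print(len(cast(nn.ModuleList, container)))  # 1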
nni/retiarii/nn/pytorch/component.py (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import copy
 import warnings
 from collections import OrderedDict
-from typing import Callable, List, Union, Tuple, Optional
+from typing import Callable, List, Dict, Union, Tuple, Optional

 import torch
 import torch.nn as nn

 from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
-from .api import LayerChoice, ValueChoice, ValueChoiceX
+from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
 from .mutation_utils import Mutable, generate_new_label, get_fixed_value
...
@@ -64,7 +67,7 @@ class Repeat(Mutable):
                                List[Callable[[int], nn.Module]],
                                nn.Module,
                                List[nn.Module]],
-                 depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None):
+                 depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
         if isinstance(depth, tuple):
             # we can't create a value choice here,
             # otherwise we will have two value choices, one created here, another in init.
...
@@ -90,7 +93,7 @@ class Repeat(Mutable):
                                List[Callable[[int], nn.Module]],
                                nn.Module,
                                List[nn.Module]],
-                 depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
+                 depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
         super().__init__()
         self._label = None  # by default, no label
...
@@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module):
             return OrderedDict([(str(i), t) for i, t in enumerate(x)])
         return OrderedDict(x)

-    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
+    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                  in_features: int, out_features: int, num_tensors: int = 4,
                  label: Optional[str] = None):
         super().__init__()
...
@@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module):
                 node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
             self.layers.append(node_ops)

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         The forward of input choice is simply selecting first on all choices.
         It shouldn't be called directly by users in most cases.
         """
-        tensors = [inputs]
+        tensors: List[torch.Tensor] = [inputs]
         for layer in self.layers:
-            current_tensor = []
-            for i, op in enumerate(layer):
-                current_tensor.append(op(tensors[i]))
-            current_tensor = torch.sum(torch.stack(current_tensor), 0)
-            tensors.append(current_tensor)
+            current_tensor: List[torch.Tensor] = []
+            for i, op in enumerate(layer):  # type: ignore
+                current_tensor.append(op(tensors[i]))  # type: ignore
+            tensors.append(torch.sum(torch.stack(current_tensor), 0))
         return tensors[-1]
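The `forward` change above adds parameter and return annotations plus an explicit element type for the accumulator list. A short, hedged sketch of the same annotation style on a toy module (purely illustrative, not NNI code):

from typing import List

import torch
import torch.nn as nn


class SumOfBranches(nn.Module):
    """Toy module: run every branch on the input and sum the results."""

    def __init__(self) -> None:
        super().__init__()
        self.branches = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []  # annotated so checkers know the element type
        for branch in self.branches:
            outputs.append(branch(inputs))
        return torch.sum(torch.stack(outputs), 0)


print(SumOfBranches()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])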
nni/retiarii/nn/pytorch/hypermodule.py (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 from packaging.version import Version

 import torch
 import torch.nn as nn
...
@@ -233,7 +235,7 @@ class AutoActivation(nn.Module):
     -----
     Current `beta` is not per-channel parameter.
     """
-    def __init__(self, unit_num: int = 1, label: str = None):
+    def __init__(self, unit_num: int = 1, label: str | None = None):
         super().__init__()
         self._label = generate_new_label(label)
         self.unaries = nn.ModuleList()
...
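The `label: str | None = None` spelling uses PEP 604 union syntax; adding `from __future__ import annotations`, as the hunk above does, makes that spelling legal in annotations on Python 3.7+ because annotations stop being evaluated at definition time. A minimal sketch, independent of NNI:

from __future__ import annotations  # must precede all other imports in the file


def greet(name: str | None = None) -> str:
    # With the future import, `str | None` is stored as a string and never
    # evaluated, so this also works on Python versions before 3.10.
    return f'Hello, {name or "world"}!'


print(greet())       # Hello, world!
print(greet('NAS'))  # Hello, NAS!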
nni/retiarii/nn/pytorch/mutation_utils.py (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 from typing import Any, Optional, Tuple, Union

 import torch.nn as nn
...
@@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]):
     return label


-def get_fixed_value(label: str) -> Any:
+def get_fixed_value(label: Optional[str]) -> Any:
     ret = get_current_context('fixed')
     try:
         return ret[generate_new_label(label)]
...
@@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any:
         raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')


-def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
+def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
     ret = get_current_context('fixed')
     try:
         label_prefix = generate_new_label(label_prefix)
...
nni/retiarii/nn/pytorch/mutator.py (view file @ 18962129)
...
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import inspect
-from typing import Any, List, Optional, Tuple, Dict, Iterator
+from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast

 import torch.nn as nn
...
@@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator):
             # Each layer choice corresponds to a cell, which is unconnected in the base graph.
             # We add the connections here in the mutation logic.
             # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
-            target = model.graphs[node.operation.cell_name]
+            target = model.graphs[cast(Cell, node.operation).cell_name]
             chosen_node = target.get_node_by_name(chosen)
             assert chosen_node is not None
             target.add_edge((target.input_node, 0), (chosen_node, None))
             target.add_edge((chosen_node, None), (target.output_node, None))
-            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            operation = cast(Cell, node.operation)
+            target_node = cast(Node, model.get_node_by_name(node.name))
+            target_node.update_operation(Cell(operation.cell_name))

             # remove redundant nodes
             for rm_node in list(target.hidden_nodes):  # remove from a list on the fly will cause issues
...
@@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator):
         else:
             chosen = [self.choice(candidates) for _ in range(n_chosen)]
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
                                     {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
...
@@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator):
         # no need to support transformation here,
         # because it is naturally done in forward loop
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
...
@@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator):
         super().__init__(label=label)
         self.candidates = candidates

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # leave a record here
         # real mutations will be done in ParameterChoiceMutator
         self.choice(self.candidates)
...
@@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator):
         self.nodes = nodes

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # looks like {"label1": "cat", "label2": 123}
         value_choice_decisions = {}
         for mutation in model.history:
...
@@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator):
             result_value = value_choice.evaluate(leaf_node_values)

             # update model with graph mutation primitives
-            target = model.get_node_by_name(node.name)
+            target = cast(Node, model.get_node_by_name(node.name))
             target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
...
@@ -138,20 +140,20 @@ class RepeatMutator(Mutator):
         while u != graph.output_node:
             if u != graph.input_node:
                 chain.append(u)
-            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.'
+            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
             u = u.successors[0]
         return chain

     def mutate(self, model):
         for node in self.nodes:
             # the logic here is similar to layer choice. We find cell attached to each node.
-            target: Graph = model.graphs[node.operation.cell_name]
+            target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
             chain = self._retrieve_chain_from_graph(target)
             # and we get the chosen depth (by value choice)
-            node_in_model = model.get_node_by_name(node.name)
+            node_in_model = cast(Node, model.get_node_by_name(node.name))
             # depth is a value choice in base model
             # but it's already mutated by a ParameterChoiceMutator here
-            chosen_depth = node_in_model.operation.parameters['depth']
+            chosen_depth: int = node_in_model.operation.parameters['depth']
             for edge in chain[chosen_depth - 1].outgoing_edges:
                 edge.remove()
             target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
...
@@ -159,8 +161,11 @@ class RepeatMutator(Mutator):
                 for edge in rm_node.outgoing_edges:
                     edge.remove()
                 rm_node.remove()
             # to delete the unused parameters.
-            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            target_node = cast(Node, model.get_node_by_name(node.name))
+            cell_operation = cast(Cell, node.operation)
+            target_node.update_operation(Cell(cell_operation.cell_name))


 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...
@@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator):
     Choose based on labels. Will not affect the model itself.
     """

-    def __init__(self, label: Optional[str]):
+    def __init__(self, label: str):
         super().__init__(label=label)

     @staticmethod
...
@@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator):
             return node.operation.parameters['n_chosen']
         return 1

-    def mutate(self, model: Model):
+    def mutate(self, model: Model) -> None:
         # this mutate does not have any effect, but it is recorded in the mutation history
         for node in model.get_nodes_by_label(self.label):
             n_chosen = self.number_of_chosen(node)
...
@@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
         if not is_model_wrapped(pytorch_model):
             raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
                              'if your model has init parameters.')
-        model.python_init_params = pytorch_model.trace_kwargs
+        model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
     else:
         model.python_init_params = {}

     # hyper-parameter choice
-    namespace: ModelNamespace = pytorch_model._model_namespace
+    namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
     for param_spec in namespace.parameter_specs:
         assert param_spec.categorical and param_spec.type == 'choice'
         node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
...
@@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
     for name, module in pytorch_model.named_modules():
         # tricky case: value choice that serves as parameters are stored in traced arguments
         if is_basic_unit(module):
-            for key, value in module.trace_kwargs.items():
+            trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
+            for key, value in trace_kwargs.items():
                 if isinstance(value, ValueChoiceX):
                     for i, choice in enumerate(value.inner_choices()):
                         node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
...
@@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
     mutators = []
     mutators_final = []
     for nodes in _group_by_label_and_type(graph.hidden_nodes):
+        label = nodes[0].label
+        assert label is not None, f'label of {nodes[0]} can not be None.'
         assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
-            f'Node with label "{nodes[0].label}" does not all have the same type.'
+            f'Node with label "{label}" does not all have the same type.'
         assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
-            f'Node with label "{nodes[0].label}" does not agree on parameters.'
+            f'Node with label "{label}" does not agree on parameters.'
         if nodes[0].operation.type == 'NasBench101Cell':
-            mutators_final.append(NasBench101Mutator(nodes[0].label))
+            # The mutation of Nas-bench-101 is special, and has to be done lastly.
+            mutators_final.append(NasBench101Mutator(label))
         else:
-            mutators.append(ManyChooseManyMutator(nodes[0].label))
+            mutators.append(ManyChooseManyMutator(label))
     return model, mutators + mutators_final
...
@@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator):
         super().__init__(label=label)
         self.candidates = candidates

-    def mutate(self, model: Model) -> Model:
+    def mutate(self, model: Model) -> None:
         # leave a record here
         # real mutations will be done in ParameterChoiceMutator
         self.choice(self.candidates)
...
@@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator):
         return obj

-    def mutate(self, model: Model):
+    def mutate(self, model: Model) -> None:
         value_choice_decisions = {}
         for mutation in model.history:
             if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
...
@@ -454,7 +463,7 @@ def _is_all_equal(lst):
     return True


-def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
+def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
         key = (node.label, node.operation.type)
...
@@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
     return list(result.values())


-def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
+def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
     result = {}
     for node in nodes:
         label = node.operation.parameters['label']
...
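Many hunks above wrap `model.get_node_by_name(...)` in `cast(Node, ...)`, the usual way to tell a strict checker that an Optional-returning lookup is known to succeed. A hedged sketch of the two common options (assert vs. cast), using made-up `Registry`/`Item` types rather than NNI's graph API:

from typing import Dict, Optional, cast


class Item:
    def __init__(self, name: str) -> None:
        self.name = name


class Registry:
    def __init__(self) -> None:
        self._items: Dict[str, Item] = {}

    def add(self, item: Item) -> None:
        self._items[item.name] = item

    def get(self, name: str) -> Optional[Item]:
        return self._items.get(name)


registry = Registry()
registry.add(Item('conv1'))

# Option 1: assert narrows Optional[Item] to Item and fails loudly if absent.
found = registry.get('conv1')
assert found is not None
print(found.name)

# Option 2: cast() when the caller is certain the item exists; no runtime check.
print(cast(Item, registry.get('conv1')).name)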
nni/retiarii/nn/pytorch/nasbench101.py (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
 import logging
 from collections import OrderedDict
-from typing import Callable, List, Optional, Union, Dict
+from typing import Callable, List, Optional, Union, Dict, Tuple, cast

 import numpy as np
 import torch
...
@@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix):
     return vertex_channels


-def prune(matrix, ops):
+def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
     """
     Prune the extraneous parts of the graph.
...
@@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module):
         assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1

-        self.operations = ['IN'] + operations + ['OUT']  # add psuedo nodes
+        raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
+        del operations  # operations is no longer needed. Delete it to avoid misuse
+
+        # add psuedo nodes
+        raw_operations.insert(0, 'IN')
+        raw_operations.append('OUT')
+
         self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
         del num_nodes  # raw number of nodes is no longer used

-        self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations)
+        self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)

         self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
...
@@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module):
             self.projections.append(projection(in_features, self.hidden_features[i]))
         for i in range(1, self.num_nodes - 1):
-            self.ops.append(operations[i - 1](self.hidden_features[i]))
+            operation = cast(Callable[[int], nn.Module], self.operations[i])
+            self.ops.append(operation(self.hidden_features[i]))

     @staticmethod
     def build_connection_matrix(adjacency_list, num_nodes):
...
@@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator):
     # for validation purposes
     # for python execution engine

-    def __init__(self, label: Optional[str]):
+    def __init__(self, label: str):
         super().__init__(label=label)

     @staticmethod
...
@@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator):
         return 1

     def mutate(self, model: Model):
+        max_num_edges = cast(int, None)
         for node in model.get_nodes_by_label(self.label):
             max_num_edges = node.operation.parameters['max_num_edges']
             break
+        assert max_num_edges is not None
         mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
         num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
         adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
...
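The `raw_operations` rewrite above avoids reassigning one variable to values with a wider element type (a list of strings becomes a list mixing strings and factory callables), which static checkers flag. A small hedged sketch of that refactoring pattern with throwaway names, not the NAS-Bench-101 code itself:

from typing import List, Union

# Before: one variable silently changes meaning (and element type) mid-function.
# After: the widened list gets its own name and annotation.

def with_sentinels(ops: List[str]) -> List[Union[str, int]]:
    raw_ops: List[Union[str, int]] = list(ops)  # copy, so the argument is untouched
    raw_ops.insert(0, 0)                        # sentinel "input" marker
    raw_ops.append(-1)                          # sentinel "output" marker
    return raw_ops


print(with_sentinels(['conv', 'pool']))  # [0, 'conv', 'pool', -1]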
nni/retiarii/nn/pytorch/nn.py (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import inspect
-import warnings
 from pathlib import Path

-import torch
-import torch.nn as nn

 # To make auto-completion happy, we generate a _nn.py that lists out all the classes.
 nn_cache_file_path = Path(__file__).parent / '_nn.py'

-cache_valid = False
-if nn_cache_file_path.exists():
-    from . import _nn  # pylint: disable=no-name-in-module
-    # valid only when torch version match
-    if _nn._torch_version == torch.__version__:
-        cache_valid = True
+# Update this when cache format changes, to enforce an update.
+cache_version = 2
+
+
+def validate_cache() -> bool:
+    import torch
+
+    cache_valid = []
+
+    if nn_cache_file_path.exists():
+        lines = nn_cache_file_path.read_text().splitlines()
+        for line in lines:
+            if line.startswith('# _torch_version'):
+                _cached_torch_version = line[line.find('=') + 1:].strip()
+                if _cached_torch_version == torch.__version__:
+                    cache_valid.append(True)
+            if line.startswith('# _torch_nn_cache_version'):
+                _cached_cache_version = int(line[line.find('=') + 1:].strip())
+                if _cached_cache_version == cache_version:
+                    cache_valid.append(True)
+
+    return len(cache_valid) >= 2 and all(cache_valid)

-if not cache_valid:
+
+def generate_stub_file() -> str:
+    import inspect
+    import warnings
+
+    import torch
+    import torch.nn as nn
+
     _NO_WRAP_CLASSES = [
         # not an nn.Module
         'Parameter',
...
@@ -47,7 +63,10 @@ if not cache_valid:
         '# This file is auto-generated to make auto-completion work.',
         '# When pytorch version does not match, it will get automatically updated.',
         '# pylint: skip-file',
-        f'_torch_version = "{torch.__version__}"',
+        '# pyright: reportGeneralTypeIssues=false',
+        f'# _torch_version = {torch.__version__}',
+        f'# _torch_nn_cache_version = {cache_version}',
+        'import typing',
         'import torch.nn as nn',
         'from nni.retiarii.serializer import basic_unit',
     ]
...
@@ -66,10 +85,9 @@ if not cache_valid:
                               'It means your PyTorch version might not be supported.', RuntimeWarning)
                 code.append(f'{name} = nn.{name}')
             elif name in _WRAP_WITHOUT_TAG_CLASSES:
-                code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
+                code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
             else:
-                code.append(f'{name} = basic_unit(nn.{name})')
+                code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')

             all_names.append(name)

         elif inspect.isfunction(obj) or inspect.ismodule(obj):
...
@@ -78,12 +96,19 @@ if not cache_valid:
     code.append(f'__all__ = {all_names}')

+    return '\n'.join(code)
+
+
+def write_cache(code: str) -> None:
     with nn_cache_file_path.open('w') as fp:
-        fp.write('\n'.join(code))
+        fp.write(code)
+
+
+code = generate_stub_file()
+if not validate_cache():
+    write_cache(code)

+# Import all modules from generated _nn.py
+del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
+from . import _nn  # pylint: disable=no-name-in-module
+__all__ = _nn.__all__

-from ._nn import *  # pylint: disable=import-error, wildcard-import, unused-wildcard-import
+from ._nn import *  # pylint: disable=import-error, wildcard-import
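The generated stub now wraps each class in `typing.cast(typing.Type[nn.X], basic_unit(nn.X, ...))` so editors keep treating the wrapped name as the original class. A hedged sketch of why that helps, using a generic do-nothing decorator instead of NNI's `basic_unit`:

import typing

import torch.nn as nn


def register(cls: typing.Any) -> typing.Any:
    """Stand-in for a wrapper whose return type the checker cannot see through."""
    return cls


# Without the cast, `Linear` would be typed as `Any` and auto-completion breaks.
Linear = typing.cast(typing.Type[nn.Linear], register(nn.Linear))

layer = Linear(4, 2)       # checked against nn.Linear's constructor signature
print(layer.weight.shape)  # torch.Size([2, 4])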
nni/retiarii/oneshot/pytorch/base_lightning.py (view file @ 18962129)
...
@@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule
 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']

-MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
+MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]


 def traverse_and_mutate_submodules(
...
@@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
         The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.

-        To be more specific, the input arguments are three arguments:
+        To be more specific, the input arguments are four arguments:

         #. a module that might be processed,
         #. name of the module in its parent module,
         #. a memo dict whose usage depends on the particular algorithm.
+        #. keyword arguments (configurations).

         Note that the memo should be read/written by hooks.
         There won't be any hooks called on root module.
...
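The `MutationHook` alias gains a fourth parameter (a keyword-argument dict). As a reminder of how a `Callable` alias like this is declared and satisfied, here is a small sketch with placeholder semantics; it illustrates the typing pattern only, not the actual NNI hook contract:

from typing import Any, Callable, Dict, List, Tuple, Union

import torch.nn as nn

# A hook receives (module, name, memo, kwargs) and returns a replacement
# module, a bool meaning "handled", or a (module, bool) pair.
Hook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]],
                Union[nn.Module, bool, Tuple[nn.Module, bool]]]


def record_only(module: nn.Module, name: str,
                memo: Dict[str, Any], kwargs: Dict[str, Any]) -> bool:
    memo[name] = type(module).__name__  # just record what was visited
    return False                        # do not replace the module


hooks: List[Hook] = [record_only]
memo: Dict[str, Any] = {}
hooks[0](nn.Linear(2, 2), 'fc', memo, {})
print(memo)  # {'fc': 'Linear'}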
nni/retiarii/oneshot/pytorch/darts.py (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore

 import copy
 import logging
 from collections import OrderedDict
...
nni/retiarii/oneshot/pytorch/proxyless.py (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore

 import logging

 import torch
...
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py (view file @ 18962129)
...
@@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor.
 """

 import operator
-from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic
+from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast

 import numpy as np
 import torch
...
@@ -128,9 +128,10 @@ class Slicable(Generic[T]):
             raise TypeError(f'Unsuppoted weight type: {type(weight)}')
         self.weight = weight

-    def __getitem__(self, index: multidim_slice) -> T:
+    def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
         if not isinstance(index, tuple):
             index = (index, )
+        index = cast(multidim_slice, index)

         # Get the dict value in index's leafs
         # There can be at most one dict
...
nni/retiarii/oneshot/pytorch/supermodule/base.py (view file @ 18962129)
...
@@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module):
     rather than their compositions.
     """

-    def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
         """
         Resample the super-net module.
...
@@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module):
         """
         raise NotImplementedError()

-    def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
         """
         Export the final architecture within this module.
         It should have the same keys as ``search_space_spec()``.
...
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py (view file @ 18962129)
...
@@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             if not arch:
                 yield name, p

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Differentiable. Do nothing in resample."""
         return {}

-    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         for name, spec in operation.search_space_spec().items():
...
nni/retiarii/oneshot/pytorch/supermodule/operation.py (view file @ 18962129)
...
@@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entang
 import inspect
 import itertools
-from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar
+from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor

 import nni.retiarii.nn.pytorch as retiarii_nn
 from nni.common.hpo_utils import ParameterSpec
...
@@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy:
         """
         pass

-    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
         """The handler of :meth:`MixedOperation.resample`."""
         raise NotImplementedError()

-    def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
         """The handler of :meth:`MixedOperation.export`."""
         raise NotImplementedError()
...
@@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         embed_dim = _W(embed_dim)

         # in projection weights & biases has q, k, v weights concatenated together
-        in_proj_bias = in_proj_weight = None
+        in_proj_bias: Optional[Tensor] = None
+        in_proj_weight: Optional[Tensor] = None
         if self.in_proj_bias is not None:
-            in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)]
+            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
         if self.in_proj_weight is not None:
-            in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim]
+            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]

-        bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None
-        bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None
-        out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim]
-        out_proj_bias = _S(self.out_proj.bias)[:embed_dim] if self.out_proj.bias is not None else None
+        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
+        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
+        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
+        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None

         if not qkv_same_embed_dim:
-            kdim = _W(kdim)
-            vdim = _W(vdim)
-            q_proj = _S(self.q_proj_weight)[:embed_dim, :embed_dim]
-            k_proj = _S(self.k_proj_weight)[:embed_dim]
-            k_proj = _S(k_proj)[:, :kdim]
-            v_proj = _S(self.v_proj_weight)[:embed_dim]
-            v_proj = _S(v_proj)[:, :vdim]
+            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
+            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
+            k_proj = _S(k_proj)[:, :_W(kdim)]
+            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
+            v_proj = _S(v_proj)[:, :_W(vdim)]

             # The rest part is basically same as pytorch
             attn_output, attn_output_weights = F.multi_head_attention_forward(
                 query, key, value, used_embed_dim, num_heads,
-                in_proj_weight, in_proj_bias,
+                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                 bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, out_proj_bias,
+                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                 training=self.training,
                 key_padding_mask=key_padding_mask, need_weights=need_weights,
                 attn_mask=attn_mask, use_separate_proj_weight=True,
                 q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
         else:
+            # Cast tensor here because of a bug in pytorch stub
             attn_output, attn_output_weights = F.multi_head_attention_forward(
                 query, key, value, used_embed_dim, num_heads,
-                in_proj_weight, in_proj_bias,
+                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                 bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, out_proj_bias,
+                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                 training=self.training,
                 key_padding_mask=key_padding_mask, need_weights=need_weights,
                 attn_mask=attn_mask)
...
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py (view file @ 18962129)
...
@@ -9,7 +9,7 @@ The support remains limited. Known limitations include:
 - The code contains duplicates. Needs refactor.
 """

-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, cast

 import torch
 import torch.nn as nn
...
@@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
             self._sample_idx = self.op_names.index(self._sampled)
         else:
             probs = self._softmax(self._arch_alpha)
-            self._sample_idx = torch.multinomial(probs, 1)[0].item()
+            self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
             self._sampled = self.op_names[self._sample_idx]

         # set binary gates
...
@@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
         """Chose the argmax if label isn't found in memo."""
         if self.label in memo:
             return {}  # nothing new to export
-        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
+        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

     def finalize_grad(self):
         binary_grads = self._binary_gates.grad
+        assert binary_grads is not None
         with torch.no_grad():
             if self._arch_alpha.grad is None:
                 self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
...
@@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput):
         else:
             probs = self._softmax(self._arch_alpha)
             sample = torch.multinomial(probs, 1)[0].item()
-            self._sampled = sample
+            self._sampled = int(sample)

         # set binary gates
         with torch.no_grad():
             self._binary_gates.zero_()
             self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[sample] = 1.0
+            self._binary_gates.data[cast(int, self._sampled)] = 1.0
         return {self.label: self._sampled}
...
@@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput):
     def finalize_grad(self):
         binary_grads = self._binary_gates.grad
+        assert binary_grads is not None
         with torch.no_grad():
             if self._arch_alpha.grad is None:
                 self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
...
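`torch.Tensor.item()` returns a Python scalar whose static type is not `int`, so the hunks above wrap sampled indices in `int(...)` before using them as list indices. A minimal, hedged sketch of that pattern (toy candidate names, not NNI's):

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
names = ['conv3x3', 'conv5x5', 'maxpool']

# .item() yields a Python scalar; int() both satisfies the type checker and
# guarantees a valid list index at runtime.
index = int(torch.multinomial(probs, 1)[0].item())
print(names[index])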
nni/retiarii/oneshot/pytorch/supermodule/sampling.py (view file @ 18962129)
...
@@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule):
         if isinstance(module, InputChoice):
             if module.reduction not in ['sum', 'mean', 'concat']:
                 raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
+            if module.n_chosen is None:
+                raise ValueError('n_chosen is None is not supported yet.')
             return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)

     def forward(self, input_tensors):
...
@@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
         self._sampled: Optional[Dict[str, Any]] = None

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Random sample for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
...
@@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         return result

-    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
...
nni/retiarii/oneshot/pytorch/utils.py
View file @ 18962129
...
@@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules):
         for name, child in m.named_children():
             if isinstance(child, type_name):
                 setattr(m, name, init_fn(child))
-                modules.append((child.key, getattr(m, name)))
+                modules.append((child.label, getattr(m, name)))
             else:
                 apply(child)
...
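The hunk above switches the recorded identifier from `child.key` to `child.label`; the surrounding traversal is the usual recursive replace-by-type pattern. A standalone sketch under simplified assumptions (the `(name, module)` record and function name below are illustrative, not NNI's internals):

import torch.nn as nn

def replace_module_with_type(root_module, init_fn, target_type, modules):
    def apply(m):
        for name, child in m.named_children():
            if isinstance(child, target_type):
                setattr(m, name, init_fn(child))          # swap the submodule in place
                modules.append((name, getattr(m, name)))  # record what replaced it
            else:
                apply(child)                              # recurse into non-matching children
    apply(root_module)
    return modules

net = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4), nn.ReLU()))
print(replace_module_with_type(net, lambda lin: nn.Identity(), nn.Linear, []))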
nni/retiarii/operation.py
View file @ 18962129
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import (Any, Dict, List)
+from typing import (Any, Dict, List, Optional, cast)

 from . import debug_configs
...
@@ -34,6 +34,8 @@ class Operation:
     Arbitrary key-value parameters (e.g. kernel_size).
     """

+    io_names: List[str] = []
+
     def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False,
                  attributes: Dict[str, Any] = {}):
         assert _internal, '`Operation()` is private, use `Operation.new()` instead'
         self.type: str = type_name
...
@@ -43,7 +45,7 @@ class Operation:
     def to_init_code(self, field: str) -> str:
         raise NotImplementedError()

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         raise NotImplementedError()

     def _to_class_name(self) -> str:
...
@@ -53,8 +55,8 @@ class Operation:
         return True

     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
-            attributes: Dict[str, Any] = None) -> 'Operation':
+    def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
+            attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
         parameters = parameters or {}
         attributes = attributes or {}
         if type_name == '_cell':
...
@@ -98,16 +100,16 @@ class PyTorchOperation(Operation):
             subclass_name = 'FunctionalOperator'
         for subclass in cls.__subclasses__():
             if hasattr(subclass, '_ori_type_name') and \
-                subclass_name in subclass._ori_type_name:
+                subclass_name in cast(Any, subclass)._ori_type_name:
                 return subclass
         for subclass in cls.__subclasses__():
             if hasattr(subclass, '_artificial_op_name') and \
-                subclass_name in subclass._artificial_op_name:
+                subclass_name in cast(Any, subclass)._artificial_op_name:
                 return subclass
         return cls

     @classmethod
-    def to_class_name(cls, type_name) -> str:
+    def to_class_name(cls, type_name) -> Optional[str]:
         if type_name.startswith('__torch__.'):
             return type_name[len('__torch__.'):]
         elif type_name.startswith('__mutated__.'):
...
@@ -119,7 +121,7 @@ class PyTorchOperation(Operation):
     def is_functional(cls, type_name) -> bool:
         return type_name.startswith('Function.')

-    def _to_class_name(self) -> str:
+    def _to_class_name(self) -> Optional[str]:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):]
         elif self.type.startswith('__mutated__.'):
...
@@ -127,7 +129,7 @@ class PyTorchOperation(Operation):
         else:
             return None

-    def get_import_pkg(self) -> str:
+    def get_import_pkg(self) -> Optional[str]:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):].split('.')[0]
         elif self.type.startswith('__mutated__.'):
...
@@ -135,14 +137,14 @@ class PyTorchOperation(Operation):
         else:
             return None

-    def to_init_code(self, field: str) -> str:
+    def to_init_code(self, field: str) -> Optional[str]:
         if self._to_class_name() is not None:
             assert 'positional_args' not in self.parameters
             kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
             return f'self.{field} = {self._to_class_name()}({kw_params})'
         return None

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         """
         Parameters
         ----------
...
@@ -207,7 +209,9 @@ class Cell(PyTorchOperation):
     No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
+    def __init__(self, cell_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None),
+                 attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
         self.type = '_cell'
         self.cell_name = cell_name
         self.parameters = parameters or {}
...
@@ -217,7 +221,7 @@ class Cell(PyTorchOperation):
         # TODO: ugly, think about how to refactor this part
         return _convert_name(self.cell_name)

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = self.{field}({", ".join(inputs)})'


 class _IOPseudoOperation(Operation):
...
@@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation):
     especially in static type checking.
     """

-    def __init__(self, type_name: str, io_names: List = None):
+    def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
         assert type_name.startswith('_')
         super(_IOPseudoOperation, self).__init__(type_name, {}, True)
         self.io_names = io_names
...
@@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation):
     def to_init_code(self, field: str) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

     def __bool__(self) -> bool:
...
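A recurring change in this file is replacing untyped `= None` defaults with `cast(SomeType, None)` so the declared annotation stays non-Optional while the runtime default remains `None`, normalised immediately with `or {}`. A minimal sketch of that idiom (the function name and returned fields are made up for illustration):

from typing import Any, Dict, cast

def new_operation(type_name: str,
                  parameters: Dict[str, Any] = cast(Dict[str, Any], None),
                  attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> Dict[str, Any]:
    parameters = parameters or {}   # None (the real default) becomes an empty dict
    attributes = attributes or {}
    return {'type': type_name, 'parameters': parameters, 'attributes': attributes}

print(new_operation('Conv2d', parameters={'kernel_size': 3}))
print(new_operation('MaxPool2d'))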
nni/retiarii/operation_def/__init__.py
View file @ 18962129
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 """
 Definition of operation types.
...
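These operation definitions are looked up by original type name: each subclass advertises the names it handles via `_ori_type_name`, and the base class scans `__subclasses__()` for a match, using the `cast(Any, subclass)` access seen in operation.py above. A hedged, simplified sketch of the idea (classes and method name are stand-ins, not NNI's API):

from typing import Any, List, cast

class Op:
    @classmethod
    def find(cls, type_name: str):
        for subclass in cls.__subclasses__():
            if hasattr(subclass, '_ori_type_name') and \
                    type_name in cast(Any, subclass)._ori_type_name:
                return subclass
        return cls  # fall back to the generic base class

class CatOp(Op):
    _ori_type_name: List[str] = ['aten::cat']

class LenOp(Op):
    _ori_type_name: List[str] = ['aten::len']

print(Op.find('aten::cat'))   # -> CatOp
print(Op.find('aten::relu'))  # -> Op (no specialised definition)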
nni/retiarii/operation_def/torch_op_def.py
View file @ 18962129
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 from typing import (Any, Dict, List)

 import torch
+import torch.nn.functional as nn_functional

 from ..operation import PyTorchOperation
...
@@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation):
     """
     _ori_type_name = ['noop_identity']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {", ".join(inputs)}'

 class ModuleOperator(PyTorchOperation):
     _ori_type_name = ['ModuleOperator', 'shared']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = self.{field}({", ".join(inputs)})'

 class FunctionalOperator(PyTorchOperation):
     _ori_type_name = ['FunctionalOperator']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         func_name = self.type[len('Function.'):]
-        if not hasattr(torch.nn.functional, func_name):
+        if not hasattr(nn_functional, func_name):
             raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
                                f'{func_name} is not in it.')
         return f'{output} = F.{func_name}({", ".join(inputs)})'
...
@@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation):
 class PrimConstant(PyTorchOperation):
     _ori_type_name = ['prim::Constant']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
         # TODO: deal with all the types
         if self.parameters['type'] in ['None', 'NoneType']:
...
@@ -87,28 +90,28 @@ class PrimConstant(PyTorchOperation):
 class PrimListConstruct(PyTorchOperation):
     _ori_type_name = ['prim::ListConstruct']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = [{", ".join(inputs)}]'

 class PrimListUnpack(PyTorchOperation):
     _ori_type_name = ['prim::ListUnpack']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]}'

 class PrimTupleConstruct(PyTorchOperation):
     _ori_type_name = ['prim::TupleConstruct']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = ({", ".join(inputs)})'

 class PrimTupleUnpack(PyTorchOperation):
     _ori_type_name = ['prim::TupleUnpack']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # have single output here, because the following code uses index to access the unpacked values
         assert len(inputs) == 1
         return f'{output} = {inputs[0]}'
...
@@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation):
 class PrimGetAttr(PyTorchOperation):
     _ori_type_name = ['prim::GetAttr']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if self.parameters['value'] is not None:
             return f"{output} = {self.parameters['value']}"
         else:
...
@@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation):
 class PrimUncheckedCast(PyTorchOperation):
     _ori_type_name = ['prim::unchecked_cast']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]}'

 class SimpleMember(PyTorchOperation):
     _ori_type_name = ['prim::is_cuda', 'prim::data']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         member_name = self.type.split('::')[-1]
         return f'{output} = {inputs[0]}.{member_name}'
...
@@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation):
 class AtenContiguous(PyTorchOperation):
     _ori_type_name = ['aten::contiguous']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # defined in pytorch/c10/core/MemoryFormat.h
-        assert inputs_value[1] in [0, 1, 2]
+        assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
         return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'

 class AtenGetitem(PyTorchOperation):
     _ori_type_name = ['aten::__getitem__']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'{output} = {inputs[0]}[{inputs[1]}]'
...
@@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation):
 class AtenAppend(PyTorchOperation):
     _ori_type_name = ['aten::append']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
...
@@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation):
 class MergedSlice(PyTorchOperation):
     _ori_type_name = ['MergedSlice']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if (len(inputs) - 1) % 4 == 0:
             slices = []
             dim = int((len(inputs) - 1) / 4)
...
@@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation):
 class AtenBool(PyTorchOperation):
     _ori_type_name = ['aten::Bool']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = bool({inputs[0]})'

 class AtenNot(PyTorchOperation):
     _ori_type_name = ['aten::__not__']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = not {inputs[0]}'

 class AtenCat(PyTorchOperation):
     _ori_type_name = ['aten::cat']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         assert len(inputs) == 2
         return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
...
@@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation):
                       'aten::new_empty', 'aten::new_zeros', 'aten::arange',
                       'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         schemas = torch._C._jit_get_schemas_for_operator(self.type)
         # match number of inputs
         overloaded_defs = [len(s.arguments) for s in schemas]
...
@@ -257,40 +260,41 @@ class AtenTensors(PyTorchOperation):
 class AtenFloordiv(PyTorchOperation):
     _ori_type_name = ['aten::floordiv']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]} // {inputs[1]}'

 class AtenMul(PyTorchOperation):
     _ori_type_name = ['aten::mul']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]} * {inputs[1]}'

 class AtenLen(PyTorchOperation):
     _ori_type_name = ['aten::len']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = len({inputs[0]})'

 class AtenIntImplicit(PyTorchOperation):
     _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         if self.type.endswith('Implicit'):
             return f'{output} = {inputs[0]}'
         elif self.type == 'aten::Int':
             return f'{output} = int({inputs[0]})'
         elif self.type == 'aten::Float':
             return f'{output} = float({inputs[0]})'
+        raise TypeError(f'Unexpected type: {self.type}')

 class AtenIndex(PyTorchOperation):
     _ori_type_name = ['aten::index']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = {inputs[0]}[{inputs[1]}]'
...
@@ -355,13 +359,13 @@ def _get_tensor_ops():
 def _get_torch_ops():
     torch_op_args = {}
-    for mod in torch.jit._builtins._modules_containing_builtins:
+    for mod in torch.jit._builtins._modules_containing_builtins:  # type: ignore
         name = mod.__name__
         if name == 'torch._C._nn':
             continue
         # only process 'torch.XXX'
         for elem in dir(mod):
-            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
+            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))  # type: ignore
             if builtin is not None:
                 schemas = torch._C._jit_get_schemas_for_operator(builtin)
                 for schema in schemas:
...
@@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation):
             return None
         raise RuntimeError(f'tensor op type {_type} has no matched')

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         # TODO: deal with conditional ops
         if self.type in TensorOps.comparison_ops:
             return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
...
@@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation):
         else:
             raise RuntimeError(f'torch op type {_type} has no matched')

-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         matched_args = TorchOps._get_matched_args(self.type, inputs)
         op_name = self.type.split('::')[-1]
         args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
...
@@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation):
     # NOTE: it is not included in the above aten ops for unkown reason
     _ori_type_name = ['aten::avg_pool2d']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = F.avg_pool2d({", ".join(inputs)})'
...
@@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation):
     _artificial_op_name = "ToDevice"

     def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
-                 attributes: Dict[str, Any] = None):
+                 attributes: Dict[str, Any] = {}):
         self.type = "ToDevice"
         self.device = parameters['device']
         self.overridden_device_repr = None
...
@@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation):
     # NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
     _ori_type_name = ['aten::linalg_det']
-    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
         return f'{output} = torch.det({inputs[0]})'
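All of the `to_forward_code` overrides above follow the same template: turn symbolic input/output names into one generated line of forward code. A small illustrative sketch of that idea, using stand-in classes rather than NNI's real `PyTorchOperation` API:

from typing import Any, List

class CatCodeGen:
    def to_forward_code(self, field: str, output: str, inputs: List[str],
                        inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'

class GetitemCodeGen:
    def to_forward_code(self, field: str, output: str, inputs: List[str],
                        inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'{output} = {inputs[0]}[{inputs[1]}]'

print(CatCodeGen().to_forward_code('op1', 'out1', ['feature_list', 'dim_1'], [None, None]))
print(GetitemCodeGen().to_forward_code('op2', 'out2', ['out1', 'idx_0'], [None, None]))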