OpenDAS / nni · Commit 8a1fdd53 (unverified)

Remove PyTorch version larger equal in Retiarii (#4622)

Authored Mar 23, 2022 by Yuge Zhang; committed by GitHub, Mar 23, 2022
Parent: 50ab44ba

Showing 11 changed files with 111 additions and 163 deletions.
Changed files (11):

nni/retiarii/nn/pytorch/.gitignore         +1   -0
nni/retiarii/nn/pytorch/__init__.py        +4   -1
nni/retiarii/nn/pytorch/api.py             +1   -1
nni/retiarii/nn/pytorch/cell.py            +2   -2
nni/retiarii/nn/pytorch/component.py       +1   -1
nni/retiarii/nn/pytorch/hypermodule.py     +1   -1
nni/retiarii/nn/pytorch/mutation_utils.py  +0   -0  (renamed from utils.py)
nni/retiarii/nn/pytorch/nasbench101.py     +1   -1
nni/retiarii/nn/pytorch/nn.py              +83  -152
nni/retiarii/serializer.py                 +10  -4
test/ut/retiarii/test_highlevel_apis.py    +7   -0
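In essence, the commit replaces nn.py's hand-maintained export list, gated on `Version(torch.__version__)` checks, with a reflection pass over `torch.nn` that writes an auto-generated `_nn.py`. A condensed, standalone sketch of the contrast (names abridged; the complete implementation is in the nn.py diff below):

    from packaging.version import Version
    import inspect
    import torch
    import torch.nn as nn

    # Before: every version-specific module needed an explicit, hand-written gate.
    legacy_exports = ['Linear', 'Conv2d']  # abridged
    if Version(torch.__version__) >= Version('1.7.0'):
        legacy_exports.append('SiLU')

    # After: whatever classes the installed torch.nn actually defines are
    # discovered at import time, so no per-version bookkeeping is needed.
    reflected_exports = [name for name, obj in inspect.getmembers(nn)
                         if inspect.isclass(obj) and issubclass(obj, nn.Module)]
    assert 'Conv2d' in reflected_exports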
nni/retiarii/nn/pytorch/.gitignore  (new file, mode 100644)

+_nn.py
nni/retiarii/nn/pytorch/__init__.py

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .api import *
 from .component import *
 from .nn import *
-from .hypermodule import *
\ No newline at end of file
+from .hypermodule import *
nni/retiarii/nn/pytorch/api.py

@@ -12,7 +12,7 @@ import torch.nn as nn
 from nni.common.serializer import Translatable
 from nni.retiarii.serializer import basic_unit
 from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
-from .utils import Mutable, generate_new_label, get_fixed_value
+from .mutation_utils import Mutable, generate_new_label, get_fixed_value

 __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
nni/retiarii/nn/pytorch/cell.py

@@ -10,8 +10,8 @@ import torch
 import torch.nn as nn

 from .api import ChosenInputs, LayerChoice, InputChoice
-from .nn import ModuleList
-from .utils import generate_new_label
+from .nn import ModuleList  # pylint: disable=no-name-in-module
+from .mutation_utils import generate_new_label


 class _ListIdentity(nn.Identity):
nni/retiarii/nn/pytorch/component.py

@@ -10,7 +10,7 @@ from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
 from .api import LayerChoice
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
-from .utils import Mutable, generate_new_label, get_fixed_value
+from .mutation_utils import Mutable, generate_new_label, get_fixed_value

 __all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
nni/retiarii/nn/pytorch/hypermodule.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 from nni.retiarii.serializer import basic_unit

 from .api import LayerChoice
-from .utils import generate_new_label
+from .mutation_utils import generate_new_label

 __all__ = ['AutoActivation']
nni/retiarii/nn/pytorch/utils.py → nni/retiarii/nn/pytorch/mutation_utils.py

File moved.
nni/retiarii/nn/pytorch/nasbench101.py

@@ -9,7 +9,7 @@ import torch.nn as nn
 from nni.retiarii.mutator import InvalidMutation, Mutator
 from nni.retiarii.graph import Model
 from .api import InputChoice, ValueChoice, LayerChoice
-from .utils import Mutable, generate_new_label, get_fixed_dict
+from .mutation_utils import Mutable, generate_new_label, get_fixed_dict

 _logger = logging.getLogger(__name__)
nni/retiarii/nn/pytorch/nn.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from packaging.version import Version
+import inspect
+import warnings
+from pathlib import Path

 import torch
 import torch.nn as nn

 from ...serializer import basic_unit

-# NOTE: support pytorch version >= 1.5.0
-__all__ = [
-    'Module', 'Sequential', 'ModuleList',
-    # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
-    'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d',
-    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
-    'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Tanh',
-    'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU',
-    'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink',
-    'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin', 'Tanhshrink', 'RReLU',
-    'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
-    'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
-    'FractionalMaxPool2d', "FractionalMaxPool3d",
-    'LPPool1d', 'LPPool2d', 'LocalResponseNorm',
-    'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
-    'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d',
-    'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
-    'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
-    'ReflectionPad1d', 'ReflectionPad2d',
-    'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
-    'CrossMapLRN2d', 'Embedding', 'EmbeddingBag',
-    'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell',
-    'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
-    'PairwiseDistance',
-    'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
-    'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
-    'TripletMarginLoss', 'ZeroPad2d',
-    'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
-    'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', 'AdaptiveLogSoftmaxWithLoss',
-    'TransformerEncoder', 'TransformerDecoder',
-    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
-    'Flatten', 'Hardsigmoid'
-]
-
-if Version(torch.__version__) >= Version('1.6.0'):
-    __all__.append('Hardswish')
-
-if Version(torch.__version__) >= Version('1.7.0'):
-    __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
-
-Module = nn.Module
-Sequential = nn.Sequential
-ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False)
-
-Identity = basic_unit(nn.Identity)
-Linear = basic_unit(nn.Linear)
-Conv1d = basic_unit(nn.Conv1d)
-Conv2d = basic_unit(nn.Conv2d)
-Conv3d = basic_unit(nn.Conv3d)
-ConvTranspose1d = basic_unit(nn.ConvTranspose1d)
-ConvTranspose2d = basic_unit(nn.ConvTranspose2d)
-ConvTranspose3d = basic_unit(nn.ConvTranspose3d)
-Threshold = basic_unit(nn.Threshold)
-ReLU = basic_unit(nn.ReLU)
-Hardtanh = basic_unit(nn.Hardtanh)
-ReLU6 = basic_unit(nn.ReLU6)
-Sigmoid = basic_unit(nn.Sigmoid)
-Tanh = basic_unit(nn.Tanh)
-Softmax = basic_unit(nn.Softmax)
-Softmax2d = basic_unit(nn.Softmax2d)
-LogSoftmax = basic_unit(nn.LogSoftmax)
-ELU = basic_unit(nn.ELU)
-SELU = basic_unit(nn.SELU)
-CELU = basic_unit(nn.CELU)
-GLU = basic_unit(nn.GLU)
-GELU = basic_unit(nn.GELU)
-Hardshrink = basic_unit(nn.Hardshrink)
-LeakyReLU = basic_unit(nn.LeakyReLU)
-LogSigmoid = basic_unit(nn.LogSigmoid)
-Softplus = basic_unit(nn.Softplus)
-Softshrink = basic_unit(nn.Softshrink)
-MultiheadAttention = basic_unit(nn.MultiheadAttention)
-PReLU = basic_unit(nn.PReLU)
-Softsign = basic_unit(nn.Softsign)
-Softmin = basic_unit(nn.Softmin)
-Tanhshrink = basic_unit(nn.Tanhshrink)
-RReLU = basic_unit(nn.RReLU)
-AvgPool1d = basic_unit(nn.AvgPool1d)
-AvgPool2d = basic_unit(nn.AvgPool2d)
-AvgPool3d = basic_unit(nn.AvgPool3d)
-MaxPool1d = basic_unit(nn.MaxPool1d)
-MaxPool2d = basic_unit(nn.MaxPool2d)
-MaxPool3d = basic_unit(nn.MaxPool3d)
-MaxUnpool1d = basic_unit(nn.MaxUnpool1d)
-MaxUnpool2d = basic_unit(nn.MaxUnpool2d)
-MaxUnpool3d = basic_unit(nn.MaxUnpool3d)
-FractionalMaxPool2d = basic_unit(nn.FractionalMaxPool2d)
-FractionalMaxPool3d = basic_unit(nn.FractionalMaxPool3d)
-LPPool1d = basic_unit(nn.LPPool1d)
-LPPool2d = basic_unit(nn.LPPool2d)
-LocalResponseNorm = basic_unit(nn.LocalResponseNorm)
-BatchNorm1d = basic_unit(nn.BatchNorm1d)
-BatchNorm2d = basic_unit(nn.BatchNorm2d)
-BatchNorm3d = basic_unit(nn.BatchNorm3d)
-InstanceNorm1d = basic_unit(nn.InstanceNorm1d)
-InstanceNorm2d = basic_unit(nn.InstanceNorm2d)
-InstanceNorm3d = basic_unit(nn.InstanceNorm3d)
-LayerNorm = basic_unit(nn.LayerNorm)
-GroupNorm = basic_unit(nn.GroupNorm)
-SyncBatchNorm = basic_unit(nn.SyncBatchNorm)
-Dropout = basic_unit(nn.Dropout)
-Dropout2d = basic_unit(nn.Dropout2d)
-Dropout3d = basic_unit(nn.Dropout3d)
-AlphaDropout = basic_unit(nn.AlphaDropout)
-FeatureAlphaDropout = basic_unit(nn.FeatureAlphaDropout)
-ReflectionPad1d = basic_unit(nn.ReflectionPad1d)
-ReflectionPad2d = basic_unit(nn.ReflectionPad2d)
-ReplicationPad2d = basic_unit(nn.ReplicationPad2d)
-ReplicationPad1d = basic_unit(nn.ReplicationPad1d)
-ReplicationPad3d = basic_unit(nn.ReplicationPad3d)
-CrossMapLRN2d = basic_unit(nn.CrossMapLRN2d)
-Embedding = basic_unit(nn.Embedding)
-EmbeddingBag = basic_unit(nn.EmbeddingBag)
-RNNBase = basic_unit(nn.RNNBase)
-RNN = basic_unit(nn.RNN)
-LSTM = basic_unit(nn.LSTM)
-GRU = basic_unit(nn.GRU)
-RNNCellBase = basic_unit(nn.RNNCellBase)
-RNNCell = basic_unit(nn.RNNCell)
-LSTMCell = basic_unit(nn.LSTMCell)
-GRUCell = basic_unit(nn.GRUCell)
-PixelShuffle = basic_unit(nn.PixelShuffle)
-Upsample = basic_unit(nn.Upsample)
-UpsamplingNearest2d = basic_unit(nn.UpsamplingNearest2d)
-UpsamplingBilinear2d = basic_unit(nn.UpsamplingBilinear2d)
-PairwiseDistance = basic_unit(nn.PairwiseDistance)
-AdaptiveMaxPool1d = basic_unit(nn.AdaptiveMaxPool1d)
-AdaptiveMaxPool2d = basic_unit(nn.AdaptiveMaxPool2d)
-AdaptiveMaxPool3d = basic_unit(nn.AdaptiveMaxPool3d)
-AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)
-AdaptiveAvgPool2d = basic_unit(nn.AdaptiveAvgPool2d)
-AdaptiveAvgPool3d = basic_unit(nn.AdaptiveAvgPool3d)
-TripletMarginLoss = basic_unit(nn.TripletMarginLoss)
-ZeroPad2d = basic_unit(nn.ZeroPad2d)
-ConstantPad1d = basic_unit(nn.ConstantPad1d)
-ConstantPad2d = basic_unit(nn.ConstantPad2d)
-ConstantPad3d = basic_unit(nn.ConstantPad3d)
-Bilinear = basic_unit(nn.Bilinear)
-CosineSimilarity = basic_unit(nn.CosineSimilarity)
-Unfold = basic_unit(nn.Unfold)
-Fold = basic_unit(nn.Fold)
-AdaptiveLogSoftmaxWithLoss = basic_unit(nn.AdaptiveLogSoftmaxWithLoss)
-TransformerEncoder = basic_unit(nn.TransformerEncoder)
-TransformerDecoder = basic_unit(nn.TransformerDecoder)
-TransformerEncoderLayer = basic_unit(nn.TransformerEncoderLayer)
-TransformerDecoderLayer = basic_unit(nn.TransformerDecoderLayer)
-Transformer = basic_unit(nn.Transformer)
-Flatten = basic_unit(nn.Flatten)
-Hardsigmoid = basic_unit(nn.Hardsigmoid)
-
-if Version(torch.__version__) >= Version('1.6.0'):
-    Hardswish = basic_unit(nn.Hardswish)
-
-if Version(torch.__version__) >= Version('1.7.0'):
-    SiLU = basic_unit(nn.SiLU)
-    Unflatten = basic_unit(nn.Unflatten)
-    TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
+# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
+nn_cache_file_path = Path(__file__).parent / '_nn.py'
+
+cache_valid = False
+
+if nn_cache_file_path.exists():
+    from . import _nn  # pylint: disable=no-name-in-module
+    # valid only when torch version match
+    if _nn._torch_version == torch.__version__:
+        cache_valid = True
+
+if not cache_valid:
+    _NO_WRAP_CLASSES = [
+        # not an nn.Module
+        'Parameter',
+        'ParameterList',
+        'UninitializedBuffer',
+        'UninitializedParameter',
+
+        # arguments are special
+        'Module',
+        'Sequential',
+
+        # utilities
+        'Container',
+        'DataParallel',
+    ]
+
+    _WRAP_WITHOUT_TAG_CLASSES = [
+        # special support on graph engine
+        'ModuleList',
+        'ModuleDict',
+    ]
+
+    code = [
+        '# This file is auto-generated to make auto-completion work.',
+        '# When pytorch version does not match, it will get automatically updated.',
+        '# pylint: skip-file',
+        f'_torch_version = "{torch.__version__}"',
+        'import torch.nn as nn',
+        'from nni.retiarii.serializer import basic_unit',
+    ]
+
+    all_names = []
+
+    # Add modules, classes, functions in torch.nn into this module.
+    for name, obj in inspect.getmembers(torch.nn):
+        if inspect.isclass(obj):
+            if name in _NO_WRAP_CLASSES:
+                code.append(f'{name} = nn.{name}')
+            elif not issubclass(obj, nn.Module):
+                # It should never go here
+                # We did it to play safe
+                warnings.warn(f'{obj} is found to be not a nn.Module, which is unexpected. '
+                              'It means your PyTorch version might not be supported.', RuntimeWarning)
+                code.append(f'{name} = nn.{name}')
+            elif name in _WRAP_WITHOUT_TAG_CLASSES:
+                code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
+            else:
+                code.append(f'{name} = basic_unit(nn.{name})')
+
+            all_names.append(name)
+
+        elif inspect.isfunction(obj) or inspect.ismodule(obj):
+            code.append(f'{name} = nn.{name}')  # no modification
+            all_names.append(name)
+
+    code.append(f'__all__ = {all_names}')
+
+    with nn_cache_file_path.open('w') as fp:
+        fp.write('\n'.join(code))
+
+# Import all modules from generated _nn.py
+from . import _nn  # pylint: disable=no-name-in-module
+__all__ = _nn.__all__
+
+from ._nn import *  # pylint: disable=import-error, wildcard-import
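With the generated _nn.py in place, the serializable namespace simply mirrors whatever the installed torch.nn exposes. A hypothetical usage sketch (assumes nni and torch are installed; whether SiLU exists depends only on the torch version, which is the point of the change):

    import nni.retiarii.nn.pytorch as nn

    conv = nn.Conv2d(1, 3, kernel_size=1)  # wrapped with basic_unit, so Retiarii can trace it
    print(hasattr(nn, 'SiLU'))             # True exactly when the installed torch.nn defines SiLU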
nni/retiarii/serializer.py

@@ -170,7 +170,13 @@ def _torchscript_patch(cls) -> None:
         cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
     if hasattr(cls, 'trace_symbol'):
         # these must all exist or all non-exist
-        cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
-        cls.trace_args = torch.jit.unused(cls.trace_args)
-        cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
-        cls.trace_copy = torch.jit.ignore(cls.trace_copy)
+        try:
+            cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
+            cls.trace_args = torch.jit.unused(cls.trace_args)
+            cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
+            cls.trace_copy = torch.jit.ignore(cls.trace_copy)
+        except AttributeError as e:
+            if 'property' in str(e):
+                raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
+                                   'Please try to upgrade PyTorch.')
+            raise
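The new try/except guards the case where the traced attributes are properties, which older TorchScript versions appear unable to mark as unused: the decorator sets an attribute on its target, and property objects reject attribute assignment with an AttributeError that mentions 'property'. A minimal stand-in sketch of that failure mode (not torch internals; mark_unused is a hypothetical illustration):

    class Example:
        @property
        def trace_symbol(self):
            return None

    def mark_unused(fn):
        fn._unused = True  # raises AttributeError: property objects have no writable attributes
        return fn

    try:
        mark_unused(Example.__dict__['trace_symbol'])
    except AttributeError as e:
        assert 'property' in str(e)  # the substring the serializer.py check keys on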
test/ut/retiarii/test_highlevel_apis.py

@@ -1001,3 +1001,10 @@ class Shared(unittest.TestCase):
         for _ in range(10):
             model = _apply_all_mutators(init_model, mutators, sampler)
             assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]
+
+    def test_retiarii_nn_import(self):
+        dummy = torch.zeros(1, 16, 32, 24)
+        nn.init.uniform_(dummy)
+        conv = nn.Conv2d(1, 3, 1)
+        param = nn.Parameter(torch.zeros(1, 3, 24, 24))
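The new test is a smoke check that the reflection-generated namespace exposes each kind of member the generator handles: nn.init (a submodule re-exported without modification), nn.Conv2d (a class wrapped with basic_unit), and nn.Parameter (listed in _NO_WRAP_CLASSES, so aliased unwrapped).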