OpenDAS / nni / Commits

Commit 8a1fdd53 (unverified)
Authored Mar 23, 2022 by Yuge Zhang; committed by GitHub on Mar 23, 2022

Remove PyTorch version larger equal in Retiarii (#4622)

Parent: 50ab44ba
Showing 11 changed files with 111 additions and 163 deletions (+111 -163):

    nni/retiarii/nn/pytorch/.gitignore                      +1   -0
    nni/retiarii/nn/pytorch/__init__.py                     +4   -1
    nni/retiarii/nn/pytorch/api.py                          +1   -1
    nni/retiarii/nn/pytorch/cell.py                         +2   -2
    nni/retiarii/nn/pytorch/component.py                    +1   -1
    nni/retiarii/nn/pytorch/hypermodule.py                  +1   -1
    nni/retiarii/nn/pytorch/mutation_utils.py (renamed)     +0   -0
    nni/retiarii/nn/pytorch/nasbench101.py                  +1   -1
    nni/retiarii/nn/pytorch/nn.py                           +83  -152
    nni/retiarii/serializer.py                              +10  -4
    test/ut/retiarii/test_highlevel_apis.py                 +7   -0
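At a high level the commit does three things, all visible in the hunks below: (1) it renames the private helper module utils.py to mutation_utils.py and updates every importer accordingly; (2) it replaces the hand-maintained, packaging.version-gated wrapper list in nn.py with code that generates _nn.py from whatever the installed torch.nn actually exports, which is what removes the "PyTorch version larger equal" checks named in the title (and why .gitignore gains an entry); (3) it turns a hard failure in serializer.py's TorchScript patching into an actionable "please upgrade PyTorch" error.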
nni/retiarii/nn/pytorch/.gitignore (new file, mode 0 → 100644)

+_nn.py
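The ignored _nn.py is the auto-completion stub that the rewritten nn.py (below) writes next to itself at import time, so it must never be committed.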
nni/retiarii/nn/pytorch/__init__.py (+4 -1; resulting file)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .api import *
from .component import *
from .nn import *
from .hypermodule import *
\ No newline at end of file
nni/retiarii/nn/pytorch/api.py

@@ -12,7 +12,7 @@ import torch.nn as nn
 from nni.common.serializer import Translatable
 from nni.retiarii.serializer import basic_unit
 from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
-from .utils import Mutable, generate_new_label, get_fixed_value
+from .mutation_utils import Mutable, generate_new_label, get_fixed_value
 
 __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
nni/retiarii/nn/pytorch/cell.py

@@ -10,8 +10,8 @@ import torch
 import torch.nn as nn
 
 from .api import ChosenInputs, LayerChoice, InputChoice
-from .nn import ModuleList
-from .utils import generate_new_label
+from .nn import ModuleList  # pylint: disable=no-name-in-module
+from .mutation_utils import generate_new_label
 
 
 class _ListIdentity(nn.Identity):
nni/retiarii/nn/pytorch/component.py

@@ -10,7 +10,7 @@ from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
 from .api import LayerChoice
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
-from .utils import Mutable, generate_new_label, get_fixed_value
+from .mutation_utils import Mutable, generate_new_label, get_fixed_value
 
 __all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
nni/retiarii/nn/pytorch/hypermodule.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 from nni.retiarii.serializer import basic_unit
 
 from .api import LayerChoice
-from .utils import generate_new_label
+from .mutation_utils import generate_new_label
 
 __all__ = ['AutoActivation']
nni/retiarii/nn/pytorch/utils.py → nni/retiarii/nn/pytorch/mutation_utils.py

File moved; contents unchanged.
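For code that imported these helpers by their private path, only the module name changes. A minimal sketch of the corresponding import update (hypothetical downstream caller; the imported names are taken from the hunks above):

    # before this commit
    from nni.retiarii.nn.pytorch.utils import Mutable, generate_new_label

    # after this commit
    from nni.retiarii.nn.pytorch.mutation_utils import Mutable, generate_new_label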
nni/retiarii/nn/pytorch/nasbench101.py

@@ -9,7 +9,7 @@ import torch.nn as nn
 from nni.retiarii.mutator import InvalidMutation, Mutator
 from nni.retiarii.graph import Model
 from .api import InputChoice, ValueChoice, LayerChoice
-from .utils import Mutable, generate_new_label, get_fixed_dict
+from .mutation_utils import Mutable, generate_new_label, get_fixed_dict
 
 
 _logger = logging.getLogger(__name__)
nni/retiarii/nn/pytorch/nn.py (+83 -152; file almost entirely rewritten)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from packaging.version import Version
+import inspect
+import warnings
+from pathlib import Path
 
 import torch
 import torch.nn as nn
 
-from ...serializer import basic_unit
-
-# NOTE: support pytorch version >= 1.5.0
-
-__all__ = [
-    'Module', 'Sequential', 'ModuleList',  # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
-    'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
-    'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
-    'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
-    'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
-    'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
-    'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
-    'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
-    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
-    'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
-    'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
-    'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
-    'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
-    'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
-    'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
-    'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
-    'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
-    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
-    'Flatten', 'Hardsigmoid'
-]
-
-if Version(torch.__version__) >= Version('1.6.0'):
-    __all__.append('Hardswish')
-
-if Version(torch.__version__) >= Version('1.7.0'):
-    __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
-
-
-Module = nn.Module
-Sequential = nn.Sequential
-ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False)
-
-Identity = basic_unit(nn.Identity)
-Linear = basic_unit(nn.Linear)
-Conv1d = basic_unit(nn.Conv1d)
-Conv2d = basic_unit(nn.Conv2d)
-Conv3d = basic_unit(nn.Conv3d)
-ConvTranspose1d = basic_unit(nn.ConvTranspose1d)
-ConvTranspose2d = basic_unit(nn.ConvTranspose2d)
-ConvTranspose3d = basic_unit(nn.ConvTranspose3d)
-Threshold = basic_unit(nn.Threshold)
-ReLU = basic_unit(nn.ReLU)
-Hardtanh = basic_unit(nn.Hardtanh)
-ReLU6 = basic_unit(nn.ReLU6)
-Sigmoid = basic_unit(nn.Sigmoid)
-Tanh = basic_unit(nn.Tanh)
-Softmax = basic_unit(nn.Softmax)
-Softmax2d = basic_unit(nn.Softmax2d)
-LogSoftmax = basic_unit(nn.LogSoftmax)
-ELU = basic_unit(nn.ELU)
-SELU = basic_unit(nn.SELU)
-CELU = basic_unit(nn.CELU)
-GLU = basic_unit(nn.GLU)
-GELU = basic_unit(nn.GELU)
-Hardshrink = basic_unit(nn.Hardshrink)
-LeakyReLU = basic_unit(nn.LeakyReLU)
-LogSigmoid = basic_unit(nn.LogSigmoid)
-Softplus = basic_unit(nn.Softplus)
-Softshrink = basic_unit(nn.Softshrink)
-MultiheadAttention = basic_unit(nn.MultiheadAttention)
-PReLU = basic_unit(nn.PReLU)
-Softsign = basic_unit(nn.Softsign)
-Softmin = basic_unit(nn.Softmin)
-Tanhshrink = basic_unit(nn.Tanhshrink)
-RReLU = basic_unit(nn.RReLU)
-AvgPool1d = basic_unit(nn.AvgPool1d)
-AvgPool2d = basic_unit(nn.AvgPool2d)
-AvgPool3d = basic_unit(nn.AvgPool3d)
-MaxPool1d = basic_unit(nn.MaxPool1d)
-MaxPool2d = basic_unit(nn.MaxPool2d)
-MaxPool3d = basic_unit(nn.MaxPool3d)
-MaxUnpool1d = basic_unit(nn.MaxUnpool1d)
-MaxUnpool2d = basic_unit(nn.MaxUnpool2d)
-MaxUnpool3d = basic_unit(nn.MaxUnpool3d)
-FractionalMaxPool2d = basic_unit(nn.FractionalMaxPool2d)
-FractionalMaxPool3d = basic_unit(nn.FractionalMaxPool3d)
-LPPool1d = basic_unit(nn.LPPool1d)
-LPPool2d = basic_unit(nn.LPPool2d)
-LocalResponseNorm = basic_unit(nn.LocalResponseNorm)
-BatchNorm1d = basic_unit(nn.BatchNorm1d)
-BatchNorm2d = basic_unit(nn.BatchNorm2d)
-BatchNorm3d = basic_unit(nn.BatchNorm3d)
-InstanceNorm1d = basic_unit(nn.InstanceNorm1d)
-InstanceNorm2d = basic_unit(nn.InstanceNorm2d)
-InstanceNorm3d = basic_unit(nn.InstanceNorm3d)
-LayerNorm = basic_unit(nn.LayerNorm)
-GroupNorm = basic_unit(nn.GroupNorm)
-SyncBatchNorm = basic_unit(nn.SyncBatchNorm)
-Dropout = basic_unit(nn.Dropout)
-Dropout2d = basic_unit(nn.Dropout2d)
-Dropout3d = basic_unit(nn.Dropout3d)
-AlphaDropout = basic_unit(nn.AlphaDropout)
-FeatureAlphaDropout = basic_unit(nn.FeatureAlphaDropout)
-ReflectionPad1d = basic_unit(nn.ReflectionPad1d)
-ReflectionPad2d = basic_unit(nn.ReflectionPad2d)
-ReplicationPad2d = basic_unit(nn.ReplicationPad2d)
-ReplicationPad1d = basic_unit(nn.ReplicationPad1d)
-ReplicationPad3d = basic_unit(nn.ReplicationPad3d)
-CrossMapLRN2d = basic_unit(nn.CrossMapLRN2d)
-Embedding = basic_unit(nn.Embedding)
-EmbeddingBag = basic_unit(nn.EmbeddingBag)
-RNNBase = basic_unit(nn.RNNBase)
-RNN = basic_unit(nn.RNN)
-LSTM = basic_unit(nn.LSTM)
-GRU = basic_unit(nn.GRU)
-RNNCellBase = basic_unit(nn.RNNCellBase)
-RNNCell = basic_unit(nn.RNNCell)
-LSTMCell = basic_unit(nn.LSTMCell)
-GRUCell = basic_unit(nn.GRUCell)
-PixelShuffle = basic_unit(nn.PixelShuffle)
-Upsample = basic_unit(nn.Upsample)
-UpsamplingNearest2d = basic_unit(nn.UpsamplingNearest2d)
-UpsamplingBilinear2d = basic_unit(nn.UpsamplingBilinear2d)
-PairwiseDistance = basic_unit(nn.PairwiseDistance)
-AdaptiveMaxPool1d = basic_unit(nn.AdaptiveMaxPool1d)
-AdaptiveMaxPool2d = basic_unit(nn.AdaptiveMaxPool2d)
-AdaptiveMaxPool3d = basic_unit(nn.AdaptiveMaxPool3d)
-AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)
-AdaptiveAvgPool2d = basic_unit(nn.AdaptiveAvgPool2d)
-AdaptiveAvgPool3d = basic_unit(nn.AdaptiveAvgPool3d)
-TripletMarginLoss = basic_unit(nn.TripletMarginLoss)
-ZeroPad2d = basic_unit(nn.ZeroPad2d)
-ConstantPad1d = basic_unit(nn.ConstantPad1d)
-ConstantPad2d = basic_unit(nn.ConstantPad2d)
-ConstantPad3d = basic_unit(nn.ConstantPad3d)
-Bilinear = basic_unit(nn.Bilinear)
-CosineSimilarity = basic_unit(nn.CosineSimilarity)
-Unfold = basic_unit(nn.Unfold)
-Fold = basic_unit(nn.Fold)
-AdaptiveLogSoftmaxWithLoss = basic_unit(nn.AdaptiveLogSoftmaxWithLoss)
-TransformerEncoder = basic_unit(nn.TransformerEncoder)
-TransformerDecoder = basic_unit(nn.TransformerDecoder)
-TransformerEncoderLayer = basic_unit(nn.TransformerEncoderLayer)
-TransformerDecoderLayer = basic_unit(nn.TransformerDecoderLayer)
-Transformer = basic_unit(nn.Transformer)
-Flatten = basic_unit(nn.Flatten)
-Hardsigmoid = basic_unit(nn.Hardsigmoid)
-
-if Version(torch.__version__) >= Version('1.6.0'):
-    Hardswish = basic_unit(nn.Hardswish)
-
-if Version(torch.__version__) >= Version('1.7.0'):
-    SiLU = basic_unit(nn.SiLU)
-    Unflatten = basic_unit(nn.Unflatten)
-    TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
+# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
+nn_cache_file_path = Path(__file__).parent / '_nn.py'
+
+cache_valid = False
+
+if nn_cache_file_path.exists():
+    from . import _nn  # pylint: disable=no-name-in-module
+    # valid only when torch version match
+    if _nn._torch_version == torch.__version__:
+        cache_valid = True
+
+if not cache_valid:
+    _NO_WRAP_CLASSES = [
+        # not an nn.Module
+        'Parameter',
+        'ParameterList',
+        'UninitializedBuffer',
+        'UninitializedParameter',
+
+        # arguments are special
+        'Module',
+        'Sequential',
+
+        # utilities
+        'Container',
+        'DataParallel',
+    ]
+
+    _WRAP_WITHOUT_TAG_CLASSES = [
+        # special support on graph engine
+        'ModuleList',
+        'ModuleDict',
+    ]
+
+    code = [
+        '# This file is auto-generated to make auto-completion work.',
+        '# When pytorch version does not match, it will get automatically updated.',
+        '# pylint: skip-file',
+        f'_torch_version = "{torch.__version__}"',
+        'import torch.nn as nn',
+        'from nni.retiarii.serializer import basic_unit',
+    ]
+
+    all_names = []
+
+    # Add modules, classes, functions in torch.nn into this module.
+    for name, obj in inspect.getmembers(torch.nn):
+        if inspect.isclass(obj):
+            if name in _NO_WRAP_CLASSES:
+                code.append(f'{name} = nn.{name}')
+            elif not issubclass(obj, nn.Module):
+                # It should never go here
+                # We did it to play safe
+                warnings.warn(f'{obj} is found to be not a nn.Module, which is unexpected. '
+                              'It means your PyTorch version might not be supported.', RuntimeWarning)
+                code.append(f'{name} = nn.{name}')
+            elif name in _WRAP_WITHOUT_TAG_CLASSES:
+                code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
+            else:
+                code.append(f'{name} = basic_unit(nn.{name})')
+
+            all_names.append(name)
+
+        elif inspect.isfunction(obj) or inspect.ismodule(obj):
+            code.append(f'{name} = nn.{name}')  # no modification
+            all_names.append(name)
+
+    code.append(f'__all__ = {all_names}')
+
+    with nn_cache_file_path.open('w') as fp:
+        fp.write('\n'.join(code))
+
+# Import all modules from generated _nn.py
+from . import _nn  # pylint: disable=no-name-in-module
+__all__ = _nn.__all__
+
+from ._nn import *  # pylint: disable=import-error, wildcard-import
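The net effect of the rewrite: instead of hand-maintaining one basic_unit wrapper per torch.nn class behind packaging.version checks, nn.py now enumerates whatever the installed torch.nn exports, writes the wrappers into _nn.py tagged with torch.__version__, and regenerates that file whenever the recorded version no longer matches. A sketch of what a generated _nn.py could look like, obtained by tracing the `code` template above (abridged; the version string "1.10.0", the annotation comments, and the exact member list are illustrative, since they depend on the installed PyTorch):

    # This file is auto-generated to make auto-completion work.
    # When pytorch version does not match, it will get automatically updated.
    # pylint: skip-file
    _torch_version = "1.10.0"
    import torch.nn as nn
    from nni.retiarii.serializer import basic_unit
    AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)  # ordinary nn.Module subclass: wrapped
    ...
    Module = nn.Module                                    # in _NO_WRAP_CLASSES: aliased as-is
    ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False)  # in _WRAP_WITHOUT_TAG_CLASSES
    Parameter = nn.Parameter                              # not an nn.Module: aliased as-is
    ...
    init = nn.init                                        # submodule: passed through unchanged
    ...
    __all__ = ['AdaptiveAvgPool1d', ..., 'Module', 'ModuleList', 'Parameter', 'init', ...]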
nni/retiarii/serializer.py

@@ -170,7 +170,13 @@ def _torchscript_patch(cls) -> None:
     cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
     if hasattr(cls, 'trace_symbol'):
         # these must all exist or all non-exist
-        cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
-        cls.trace_args = torch.jit.unused(cls.trace_args)
-        cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
-        cls.trace_copy = torch.jit.ignore(cls.trace_copy)
+        try:
+            cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
+            cls.trace_args = torch.jit.unused(cls.trace_args)
+            cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
+            cls.trace_copy = torch.jit.ignore(cls.trace_copy)
+        except AttributeError as e:
+            if 'property' in str(e):
+                raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
+                                   'Please try to upgrade PyTorch.')
+            raise
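The failure this guards against, as the patch's own check for 'property' in the message suggests, is that on some older PyTorch versions torch.jit.unused cannot wrap an attribute implemented as a property and raises an AttributeError mentioning 'property'. A minimal sketch of that situation under this assumption (hypothetical Traced class, not NNI code; it mirrors the patched logic):

    import torch

    class Traced:
        @property
        def trace_symbol(self):
            return None

    try:
        # On an outdated PyTorch this may raise AttributeError mentioning 'property'.
        Traced.trace_symbol = torch.jit.unused(Traced.trace_symbol)
    except AttributeError as e:
        if 'property' in str(e):
            raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
                               'Please try to upgrade PyTorch.')
        raise

This lets the commit drop the explicit version gates elsewhere while still giving users an actionable message at the point where an old PyTorch actually bites.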
test/ut/retiarii/test_highlevel_apis.py

@@ -1001,3 +1001,10 @@ class Shared(unittest.TestCase):
         for _ in range(10):
             model = _apply_all_mutators(init_model, mutators, sampler)
             assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]
+
+    def test_retiarii_nn_import(self):
+        dummy = torch.zeros(1, 16, 32, 24)
+        nn.init.uniform_(dummy)
+        conv = nn.Conv2d(1, 3, 1)
+        param = nn.Parameter(torch.zeros(1, 3, 24, 24))
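The new test is deliberately shallow: merely touching nn.init, nn.Parameter, and a wrapped class such as nn.Conv2d verifies that all three kinds of members (pass-through submodules, non-Module classes, and wrapped nn.Module subclasses) resolve through the generated _nn.py. It presumably runs against the test module's existing `import nni.retiarii.nn.pytorch as nn` binding, which lies outside this hunk.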