Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cae4308f
Unverified
Commit
cae4308f
authored
Mar 03, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 03, 2021
Browse files
[Retiarii] Rename APIs and refine documentation (#3404)
parent
d047d6f4
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
67 additions
and
88 deletions
+67
-88
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+5
-7
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+4
-4
test/ut/retiarii/converted_mnist_pytorch.json
test/ut/retiarii/converted_mnist_pytorch.json
+1
-1
test/ut/retiarii/imported/model.py
test/ut/retiarii/imported/model.py
+2
-2
test/ut/retiarii/inject_nn.py
test/ut/retiarii/inject_nn.py
+2
-18
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+1
-1
test/ut/retiarii/mnist_pytorch.json
test/ut/retiarii/mnist_pytorch.json
+1
-1
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-1
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+17
-15
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+1
-2
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+0
-2
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+1
-2
test/ut/retiarii/test_dedup_input.py
test/ut/retiarii/test_dedup_input.py
+1
-2
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+1
-1
test/ut/retiarii/test_graph.py
test/ut/retiarii/test_graph.py
+1
-1
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+2
-2
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+7
-7
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+17
-17
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+2
-2
No files found.
test/retiarii_test/mnasnet/test.py
View file @
cae4308f
...
...
@@ -3,10 +3,8 @@ import sys
import
torch
from
pathlib
import
Path
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
import
nni.retiarii.trainer.pytorch.lightning
as
pl
from
nni.retiarii
import
blackbox_module
as
bm
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
nni.retiarii
import
serialize
from
base_mnasnet
import
MNASNet
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategy
import
TPEStrategy
...
...
@@ -35,8 +33,8 @@ if __name__ == '__main__':
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
train_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
train_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.2
)
...
...
@@ -56,4 +54,4 @@ if __name__ == '__main__':
exp_config
.
max_trial_number
=
10
exp_config
.
training_service
.
use_active_gpu
=
False
exp
.
run
(
exp_config
,
80
81
)
exp
.
run
(
exp_config
,
80
97
)
test/retiarii_test/mnist/test.py
View file @
cae4308f
...
...
@@ -2,9 +2,9 @@ import random
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.strategy
as
strategy
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii
import
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
...
...
@@ -36,8 +36,8 @@ class Net(nn.Module):
if
__name__
==
'__main__'
:
base_model
=
Net
(
128
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
)
...
...
test/ut/retiarii/converted_mnist_pytorch.json
View file @
cae4308f
...
...
@@ -340,7 +340,7 @@
}
]
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"kwargs"
:
{
"dataset_cls"
:
"MNIST"
,
...
...
test/ut/retiarii/imported/model.py
View file @
cae4308f
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
@
b
lackbox_module
@
b
asic_unit
class
ImportTest
(
nn
.
Module
):
def
__init__
(
self
,
foo
,
bar
):
super
().
__init__
()
...
...
test/ut/retiarii/inject_nn.py
View file @
cae4308f
...
...
@@ -4,7 +4,7 @@ import logging
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
add_record
,
del_record
,
version_larger_equal
from
nni.retiarii.utils
import
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -13,39 +13,23 @@ def wrap_module(original_class):
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
# Make copy of original __init__, so we can call it without recursion
original_class
.
bak_init_for_inject
=
orig_init
if
hasattr
(
original_class
,
'__del__'
):
orig_del
=
original_class
.
__del__
original_class
.
bak_del_for_inject
=
orig_del
else
:
orig_del
=
None
original_class
.
bak_del_for_inject
=
None
def
__init__
(
self
,
*
args
,
**
kws
):
full_args
=
{}
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
self
.
_init_parameters
=
full_args
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
def
__del__
(
self
):
del_record
(
id
(
self
))
if
orig_del
is
not
None
:
orig_del
(
self
)
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__del__
=
__del__
return
original_class
def
unwrap_module
(
wrapped_class
):
if
hasattr
(
wrapped_class
,
'bak_init_for_inject'
):
wrapped_class
.
__init__
=
wrapped_class
.
bak_init_for_inject
delattr
(
wrapped_class
,
'bak_init_for_inject'
)
if
hasattr
(
wrapped_class
,
'bak_del_for_inject'
):
if
wrapped_class
.
bak_del_for_inject
is
not
None
:
wrapped_class
.
__del__
=
wrapped_class
.
bak_del_for_inject
delattr
(
wrapped_class
,
'bak_del_for_inject'
)
return
None
def
remove_inject_pytorch_nn
():
...
...
test/ut/retiarii/mnist-tensorflow.json
View file @
cae4308f
...
...
@@ -38,7 +38,7 @@
]
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"__type__"
:
"_debug_no_trainer"
}
}
test/ut/retiarii/mnist_pytorch.json
View file @
cae4308f
...
...
@@ -38,7 +38,7 @@
]
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"kwargs"
:
{
"dataset_cls"
:
"MNIST"
,
...
...
test/ut/retiarii/test_cgo_engine.py
View file @
cae4308f
...
...
@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.
traine
r.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.
evaluato
r.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_convert.py
View file @
cae4308f
...
...
@@ -12,10 +12,9 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -35,8 +34,8 @@ class MnistNet(nn.Module):
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
# NOTE:
blackbox
module cannot be placed within class or function
@
b
lackbox_module
# NOTE:
serialize
module cannot be placed within class or function
@
b
asic_unit
class
Linear
(
nn
.
Module
):
def
__init__
(
self
,
d_embed
,
d_proj
):
super
().
__init__
()
...
...
@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase):
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
...
...
@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
def
test_torchvision_resnet18
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
finally
:
remove_inject_pytorch_nn
()
def
test_resnet
(
self
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
...
...
@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
resnet18
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
def
test_alexnet
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
remove_inject_pytorch_nn
()
test/ut/retiarii/test_convert_basic.py
View file @
cae4308f
...
...
@@ -8,10 +8,9 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
...
...
test/ut/retiarii/test_convert_operators.py
View file @
cae4308f
...
...
@@ -15,10 +15,8 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
cae4308f
...
...
@@ -14,10 +14,9 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_modul
e
from
nni.retiarii
import
serializ
e
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
TestPytorch
(
unittest
.
TestCase
):
...
...
test/ut/retiarii/test_dedup_input.py
View file @
cae4308f
...
...
@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
...
...
@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase):
# sys.path.insert(0, 'generated')
# multi_model = import_('debug_dedup_input.logical_0')
# trainer = PyTorchMultiModelTrainer(
# multi_model(), phy_models[0][0].
training_config
.kwargs
# multi_model(), phy_models[0][0].
evaluator
.kwargs
# )
# trainer.fit()
...
...
test/ut/retiarii/test_engine.py
View file @
cae4308f
...
...
@@ -9,7 +9,7 @@ import nni
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
,
register_advisor
from
nni.retiarii.
traine
r.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.
evaluato
r.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_graph.py
View file @
cae4308f
...
...
@@ -23,7 +23,7 @@ def _test_file(json_path):
# add default values to JSON, so we can compare with `==`
for
graph_name
,
graph
in
orig_ir
.
items
():
if
graph_name
==
'_
training_config
'
:
if
graph_name
==
'_
evaluator
'
:
continue
if
'inputs'
not
in
graph
:
graph
[
'inputs'
]
=
None
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
cae4308f
...
...
@@ -4,7 +4,7 @@ import unittest
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch.nn.functional
as
F
from
nni.retiarii
import
Sampler
,
b
lackbox_module
from
nni.retiarii
import
Sampler
,
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
...
...
@@ -29,7 +29,7 @@ class RandomSampler(Sampler):
return
random
.
choice
(
candidates
)
@
b
lackbox_module
@
b
asic_unit
class
MutableConv
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
cae4308f
...
...
@@ -2,13 +2,13 @@ import json
import
pytest
import
nni
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
pytorch_lightning
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii.
traine
r
import
Functional
Traine
r
from
nni.retiarii
import
serialize_cls
,
serialize
from
nni.retiarii.
evaluato
r
import
Functional
Evaluato
r
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
from
torchvision
import
transforms
...
...
@@ -49,7 +49,7 @@ class FCNet(nn.Module):
return
output
.
view
(
-
1
)
@
bm
@
serialize_cls
class
DiabetesDataset
(
Dataset
):
def
__init__
(
self
,
train
=
True
):
data
=
load_diabetes
()
...
...
@@ -91,8 +91,8 @@ def _reset():
def
test_mnist
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
...
...
@@ -121,7 +121,7 @@ def test_diabetes():
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_functional
():
Functional
Traine
r
(
_foo
).
_execute
(
MNISTModel
)
Functional
Evaluato
r
(
_foo
).
_execute
(
MNISTModel
)
if
__name__
==
'__main__'
:
...
...
test/ut/retiarii/test_serializer.py
View file @
cae4308f
...
...
@@ -4,7 +4,7 @@ import re
import
sys
import
torch
from
nni.retiarii
import
json_dumps
,
json_loads
,
blackbox
from
nni.retiarii
import
json_dumps
,
json_loads
,
serialize
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
...
...
@@ -23,30 +23,30 @@ class Foo:
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_
blackbox
():
module
=
blackbox
(
Foo
,
3
)
def
test_
serialize
():
module
=
serialize
(
Foo
,
3
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
b
=
2
,
a
=
1
)
module
=
serialize
(
Foo
,
b
=
2
,
a
=
1
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
Foo
(
1
),
5
)
module
=
serialize
(
Foo
,
Foo
(
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
blackbox
(
Foo
,
blackbox
(
Foo
,
1
),
5
)
module
=
serialize
(
Foo
,
serialize
(
Foo
,
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
json_loads
(
dumped_module
)
==
module
def
test_b
lackbox_module
():
def
test_b
asic_unit
():
module
=
ImportTest
(
3
,
0.5
)
assert
json_loads
(
json_dumps
(
module
))
==
module
def
test_dataset
():
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
dumped_ans
=
{
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
...
...
@@ -62,19 +62,19 @@ def test_dataset():
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
blackbox
(
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
serialize
(
transforms
.
Compose
,
[
blackbox
(
transforms
.
ToTensor
),
blackbox
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
[
serialize
(
transforms
.
ToTensor
),
serialize
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
...
...
@@ -87,7 +87,7 @@ def test_type():
if
__name__
==
'__main__'
:
test_
blackbox
()
test_b
lackbox_module
()
test_
serialize
()
test_b
asic_unit
()
test_dataset
()
test_type
()
test/ut/retiarii/test_strategy.py
View file @
cae4308f
...
...
@@ -12,7 +12,7 @@ from nni.retiarii import Model
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.execution
import
wait_models
from
nni.retiarii.execution.interface
import
AbstractExecutionEngine
,
WorkerInfo
,
MetricData
,
AbstractGraphListener
from
nni.retiarii.graph
import
Debug
Training
,
ModelStatus
from
nni.retiarii.graph
import
Debug
Evaluator
,
ModelStatus
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
...
...
@@ -80,7 +80,7 @@ def _get_model_and_mutators():
base_model
=
Net
()
script_module
=
torch
.
jit
.
script
(
base_model
)
base_model_ir
=
convert_to_graph
(
script_module
,
base_model
)
base_model_ir
.
training_config
=
DebugTraining
()
base_model_ir
.
evaluator
=
DebugEvaluator
()
mutators
=
process_inline_mutation
(
base_model_ir
)
return
base_model_ir
,
mutators
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment