Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cae4308f
Unverified
Commit
cae4308f
authored
Mar 03, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 03, 2021
Browse files
[Retiarii] Rename APIs and refine documentation (#3404)
parent
d047d6f4
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
67 additions
and
88 deletions
+67
-88
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+5
-7
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+4
-4
test/ut/retiarii/converted_mnist_pytorch.json
test/ut/retiarii/converted_mnist_pytorch.json
+1
-1
test/ut/retiarii/imported/model.py
test/ut/retiarii/imported/model.py
+2
-2
test/ut/retiarii/inject_nn.py
test/ut/retiarii/inject_nn.py
+2
-18
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+1
-1
test/ut/retiarii/mnist_pytorch.json
test/ut/retiarii/mnist_pytorch.json
+1
-1
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-1
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+17
-15
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+1
-2
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+0
-2
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+1
-2
test/ut/retiarii/test_dedup_input.py
test/ut/retiarii/test_dedup_input.py
+1
-2
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+1
-1
test/ut/retiarii/test_graph.py
test/ut/retiarii/test_graph.py
+1
-1
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+2
-2
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+7
-7
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+17
-17
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+2
-2
No files found.
test/retiarii_test/mnasnet/test.py
View file @
cae4308f
...
...
@@ -3,10 +3,8 @@ import sys
import
torch
from
pathlib
import
Path
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
import
nni.retiarii.trainer.pytorch.lightning
as
pl
from
nni.retiarii
import
blackbox_module
as
bm
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
nni.retiarii
import
serialize
from
base_mnasnet
import
MNASNet
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategy
import
TPEStrategy
...
...
@@ -35,8 +33,8 @@ if __name__ == '__main__':
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
train_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
train_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.2
)
...
...
@@ -56,4 +54,4 @@ if __name__ == '__main__':
exp_config
.
max_trial_number
=
10
exp_config
.
training_service
.
use_active_gpu
=
False
exp
.
run
(
exp_config
,
80
81
)
exp
.
run
(
exp_config
,
80
97
)
test/retiarii_test/mnist/test.py
View file @
cae4308f
...
...
@@ -2,9 +2,9 @@ import random
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.strategy
as
strategy
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii
import
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
...
...
@@ -36,8 +36,8 @@ class Net(nn.Module):
if
__name__
==
'__main__'
:
base_model
=
Net
(
128
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
)
...
...
test/ut/retiarii/converted_mnist_pytorch.json
View file @
cae4308f
...
...
@@ -340,7 +340,7 @@
}
]
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"kwargs"
:
{
"dataset_cls"
:
"MNIST"
,
...
...
test/ut/retiarii/imported/model.py
View file @
cae4308f
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
@
b
lackbox_module
@
b
asic_unit
class
ImportTest
(
nn
.
Module
):
def
__init__
(
self
,
foo
,
bar
):
super
().
__init__
()
...
...
test/ut/retiarii/inject_nn.py
View file @
cae4308f
...
...
@@ -4,7 +4,7 @@ import logging
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
add_record
,
del_record
,
version_larger_equal
from
nni.retiarii.utils
import
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -13,39 +13,23 @@ def wrap_module(original_class):
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
# Make copy of original __init__, so we can call it without recursion
original_class
.
bak_init_for_inject
=
orig_init
if
hasattr
(
original_class
,
'__del__'
):
orig_del
=
original_class
.
__del__
original_class
.
bak_del_for_inject
=
orig_del
else
:
orig_del
=
None
original_class
.
bak_del_for_inject
=
None
def
__init__
(
self
,
*
args
,
**
kws
):
full_args
=
{}
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
self
.
_init_parameters
=
full_args
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
def
__del__
(
self
):
del_record
(
id
(
self
))
if
orig_del
is
not
None
:
orig_del
(
self
)
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__del__
=
__del__
return
original_class
def
unwrap_module
(
wrapped_class
):
if
hasattr
(
wrapped_class
,
'bak_init_for_inject'
):
wrapped_class
.
__init__
=
wrapped_class
.
bak_init_for_inject
delattr
(
wrapped_class
,
'bak_init_for_inject'
)
if
hasattr
(
wrapped_class
,
'bak_del_for_inject'
):
if
wrapped_class
.
bak_del_for_inject
is
not
None
:
wrapped_class
.
__del__
=
wrapped_class
.
bak_del_for_inject
delattr
(
wrapped_class
,
'bak_del_for_inject'
)
return
None
def
remove_inject_pytorch_nn
():
...
...
test/ut/retiarii/mnist-tensorflow.json
View file @
cae4308f
...
...
@@ -38,7 +38,7 @@
]
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"__type__"
:
"_debug_no_trainer"
}
}
test/ut/retiarii/mnist_pytorch.json
View file @
cae4308f
...
...
@@ -38,7 +38,7 @@
]
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"kwargs"
:
{
"dataset_cls"
:
"MNIST"
,
...
...
test/ut/retiarii/test_cgo_engine.py
View file @
cae4308f
...
...
@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.
traine
r.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.
evaluato
r.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_convert.py
View file @
cae4308f
...
...
@@ -12,10 +12,9 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -35,8 +34,8 @@ class MnistNet(nn.Module):
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
# NOTE:
blackbox
module cannot be placed within class or function
@
b
lackbox_module
# NOTE:
serialize
module cannot be placed within class or function
@
b
asic_unit
class
Linear
(
nn
.
Module
):
def
__init__
(
self
,
d_embed
,
d_proj
):
super
().
__init__
()
...
...
@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase):
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
...
...
@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
def
test_torchvision_resnet18
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
finally
:
remove_inject_pytorch_nn
()
def
test_resnet
(
self
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
...
...
@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
resnet18
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
def
test_alexnet
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
remove_inject_pytorch_nn
()
test/ut/retiarii/test_convert_basic.py
View file @
cae4308f
...
...
@@ -8,10 +8,9 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
...
...
test/ut/retiarii/test_convert_operators.py
View file @
cae4308f
...
...
@@ -15,10 +15,8 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
cae4308f
...
...
@@ -14,10 +14,9 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_modul
e
from
nni.retiarii
import
serializ
e
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
TestPytorch
(
unittest
.
TestCase
):
...
...
test/ut/retiarii/test_dedup_input.py
View file @
cae4308f
...
...
@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
...
...
@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase):
# sys.path.insert(0, 'generated')
# multi_model = import_('debug_dedup_input.logical_0')
# trainer = PyTorchMultiModelTrainer(
# multi_model(), phy_models[0][0].
training_config
.kwargs
# multi_model(), phy_models[0][0].
evaluator
.kwargs
# )
# trainer.fit()
...
...
test/ut/retiarii/test_engine.py
View file @
cae4308f
...
...
@@ -9,7 +9,7 @@ import nni
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
,
register_advisor
from
nni.retiarii.
traine
r.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.
evaluato
r.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_graph.py
View file @
cae4308f
...
...
@@ -23,7 +23,7 @@ def _test_file(json_path):
# add default values to JSON, so we can compare with `==`
for
graph_name
,
graph
in
orig_ir
.
items
():
if
graph_name
==
'_
training_config
'
:
if
graph_name
==
'_
evaluator
'
:
continue
if
'inputs'
not
in
graph
:
graph
[
'inputs'
]
=
None
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
cae4308f
...
...
@@ -4,7 +4,7 @@ import unittest
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch.nn.functional
as
F
from
nni.retiarii
import
Sampler
,
b
lackbox_module
from
nni.retiarii
import
Sampler
,
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
...
...
@@ -29,7 +29,7 @@ class RandomSampler(Sampler):
return
random
.
choice
(
candidates
)
@
b
lackbox_module
@
b
asic_unit
class
MutableConv
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
cae4308f
...
...
@@ -2,13 +2,13 @@ import json
import
pytest
import
nni
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
pytorch_lightning
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii.
traine
r
import
Functional
Traine
r
from
nni.retiarii
import
serialize_cls
,
serialize
from
nni.retiarii.
evaluato
r
import
Functional
Evaluato
r
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
from
torchvision
import
transforms
...
...
@@ -49,7 +49,7 @@ class FCNet(nn.Module):
return
output
.
view
(
-
1
)
@
bm
@
serialize_cls
class
DiabetesDataset
(
Dataset
):
def
__init__
(
self
,
train
=
True
):
data
=
load_diabetes
()
...
...
@@ -91,8 +91,8 @@ def _reset():
def
test_mnist
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
...
...
@@ -121,7 +121,7 @@ def test_diabetes():
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_functional
():
Functional
Traine
r
(
_foo
).
_execute
(
MNISTModel
)
Functional
Evaluato
r
(
_foo
).
_execute
(
MNISTModel
)
if
__name__
==
'__main__'
:
...
...
test/ut/retiarii/test_serializer.py
View file @
cae4308f
...
...
@@ -4,7 +4,7 @@ import re
import
sys
import
torch
from
nni.retiarii
import
json_dumps
,
json_loads
,
blackbox
from
nni.retiarii
import
json_dumps
,
json_loads
,
serialize
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
...
...
@@ -23,30 +23,30 @@ class Foo:
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_
blackbox
():
module
=
blackbox
(
Foo
,
3
)
def
test_
serialize
():
module
=
serialize
(
Foo
,
3
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
b
=
2
,
a
=
1
)
module
=
serialize
(
Foo
,
b
=
2
,
a
=
1
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
Foo
(
1
),
5
)
module
=
serialize
(
Foo
,
Foo
(
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
blackbox
(
Foo
,
blackbox
(
Foo
,
1
),
5
)
module
=
serialize
(
Foo
,
serialize
(
Foo
,
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
json_loads
(
dumped_module
)
==
module
def
test_b
lackbox_module
():
def
test_b
asic_unit
():
module
=
ImportTest
(
3
,
0.5
)
assert
json_loads
(
json_dumps
(
module
))
==
module
def
test_dataset
():
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
dumped_ans
=
{
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
...
...
@@ -62,19 +62,19 @@ def test_dataset():
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
blackbox
(
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
serialize
(
transforms
.
Compose
,
[
blackbox
(
transforms
.
ToTensor
),
blackbox
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
[
serialize
(
transforms
.
ToTensor
),
serialize
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
...
...
@@ -87,7 +87,7 @@ def test_type():
if
__name__
==
'__main__'
:
test_
blackbox
()
test_b
lackbox_module
()
test_
serialize
()
test_b
asic_unit
()
test_dataset
()
test_type
()
test/ut/retiarii/test_strategy.py
View file @
cae4308f
...
...
@@ -12,7 +12,7 @@ from nni.retiarii import Model
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.execution
import
wait_models
from
nni.retiarii.execution.interface
import
AbstractExecutionEngine
,
WorkerInfo
,
MetricData
,
AbstractGraphListener
from
nni.retiarii.graph
import
Debug
Training
,
ModelStatus
from
nni.retiarii.graph
import
Debug
Evaluator
,
ModelStatus
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
...
...
@@ -80,7 +80,7 @@ def _get_model_and_mutators():
base_model
=
Net
()
script_module
=
torch
.
jit
.
script
(
base_model
)
base_model_ir
=
convert_to_graph
(
script_module
,
base_model
)
base_model_ir
.
training_config
=
DebugTraining
()
base_model_ir
.
evaluator
=
DebugEvaluator
()
mutators
=
process_inline_mutation
(
base_model_ir
)
return
base_model_ir
,
mutators
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment