Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cae4308f
Unverified
Commit
cae4308f
authored
Mar 03, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 03, 2021
Browse files
[Retiarii] Rename APIs and refine documentation (#3404)
parent
d047d6f4
Changes
59
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
67 additions
and
88 deletions
+67
-88
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+5
-7
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+4
-4
test/ut/retiarii/converted_mnist_pytorch.json
test/ut/retiarii/converted_mnist_pytorch.json
+1
-1
test/ut/retiarii/imported/model.py
test/ut/retiarii/imported/model.py
+2
-2
test/ut/retiarii/inject_nn.py
test/ut/retiarii/inject_nn.py
+2
-18
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+1
-1
test/ut/retiarii/mnist_pytorch.json
test/ut/retiarii/mnist_pytorch.json
+1
-1
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-1
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+17
-15
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+1
-2
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+0
-2
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+1
-2
test/ut/retiarii/test_dedup_input.py
test/ut/retiarii/test_dedup_input.py
+1
-2
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+1
-1
test/ut/retiarii/test_graph.py
test/ut/retiarii/test_graph.py
+1
-1
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+2
-2
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+7
-7
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+17
-17
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+2
-2
No files found.
test/retiarii_test/mnasnet/test.py
View file @
cae4308f
...
@@ -3,10 +3,8 @@ import sys
...
@@ -3,10 +3,8 @@ import sys
import
torch
import
torch
from
pathlib
import
Path
from
pathlib
import
Path
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
nni.retiarii
import
serialize
import
nni.retiarii.trainer.pytorch.lightning
as
pl
from
nni.retiarii
import
blackbox_module
as
bm
from
base_mnasnet
import
MNASNet
from
base_mnasnet
import
MNASNet
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategy
import
TPEStrategy
from
nni.retiarii.strategy
import
TPEStrategy
...
@@ -35,8 +33,8 @@ if __name__ == '__main__':
...
@@ -35,8 +33,8 @@ if __name__ == '__main__':
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
])
train_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
train_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
test_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.2
)
max_epochs
=
1
,
limit_train_batches
=
0.2
)
...
@@ -56,4 +54,4 @@ if __name__ == '__main__':
...
@@ -56,4 +54,4 @@ if __name__ == '__main__':
exp_config
.
max_trial_number
=
10
exp_config
.
max_trial_number
=
10
exp_config
.
training_service
.
use_active_gpu
=
False
exp_config
.
training_service
.
use_active_gpu
=
False
exp
.
run
(
exp_config
,
80
81
)
exp
.
run
(
exp_config
,
80
97
)
test/retiarii_test/mnist/test.py
View file @
cae4308f
...
@@ -2,9 +2,9 @@ import random
...
@@ -2,9 +2,9 @@ import random
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.strategy
as
strategy
import
nni.retiarii.strategy
as
strategy
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii
import
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision
import
transforms
...
@@ -36,8 +36,8 @@ class Net(nn.Module):
...
@@ -36,8 +36,8 @@ class Net(nn.Module):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
base_model
=
Net
(
128
)
base_model
=
Net
(
128
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
)
max_epochs
=
2
)
...
...
test/ut/retiarii/converted_mnist_pytorch.json
View file @
cae4308f
...
@@ -340,7 +340,7 @@
...
@@ -340,7 +340,7 @@
}
}
]
]
},
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"kwargs"
:
{
"kwargs"
:
{
"dataset_cls"
:
"MNIST"
,
"dataset_cls"
:
"MNIST"
,
...
...
test/ut/retiarii/imported/model.py
View file @
cae4308f
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
@
b
lackbox_module
@
b
asic_unit
class
ImportTest
(
nn
.
Module
):
class
ImportTest
(
nn
.
Module
):
def
__init__
(
self
,
foo
,
bar
):
def
__init__
(
self
,
foo
,
bar
):
super
().
__init__
()
super
().
__init__
()
...
...
test/ut/retiarii/inject_nn.py
View file @
cae4308f
...
@@ -4,7 +4,7 @@ import logging
...
@@ -4,7 +4,7 @@ import logging
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.retiarii.utils
import
add_record
,
del_record
,
version_larger_equal
from
nni.retiarii.utils
import
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -13,39 +13,23 @@ def wrap_module(original_class):
...
@@ -13,39 +13,23 @@ def wrap_module(original_class):
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
# Make copy of original __init__, so we can call it without recursion
# Make copy of original __init__, so we can call it without recursion
original_class
.
bak_init_for_inject
=
orig_init
original_class
.
bak_init_for_inject
=
orig_init
if
hasattr
(
original_class
,
'__del__'
):
orig_del
=
original_class
.
__del__
original_class
.
bak_del_for_inject
=
orig_del
else
:
orig_del
=
None
original_class
.
bak_del_for_inject
=
None
def
__init__
(
self
,
*
args
,
**
kws
):
def
__init__
(
self
,
*
args
,
**
kws
):
full_args
=
{}
full_args
=
{}
full_args
.
update
(
kws
)
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
self
.
_init_parameters
=
full_args
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
def
__del__
(
self
):
del_record
(
id
(
self
))
if
orig_del
is
not
None
:
orig_del
(
self
)
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__del__
=
__del__
return
original_class
return
original_class
def
unwrap_module
(
wrapped_class
):
def
unwrap_module
(
wrapped_class
):
if
hasattr
(
wrapped_class
,
'bak_init_for_inject'
):
if
hasattr
(
wrapped_class
,
'bak_init_for_inject'
):
wrapped_class
.
__init__
=
wrapped_class
.
bak_init_for_inject
wrapped_class
.
__init__
=
wrapped_class
.
bak_init_for_inject
delattr
(
wrapped_class
,
'bak_init_for_inject'
)
delattr
(
wrapped_class
,
'bak_init_for_inject'
)
if
hasattr
(
wrapped_class
,
'bak_del_for_inject'
):
if
wrapped_class
.
bak_del_for_inject
is
not
None
:
wrapped_class
.
__del__
=
wrapped_class
.
bak_del_for_inject
delattr
(
wrapped_class
,
'bak_del_for_inject'
)
return
None
return
None
def
remove_inject_pytorch_nn
():
def
remove_inject_pytorch_nn
():
...
...
test/ut/retiarii/mnist-tensorflow.json
View file @
cae4308f
...
@@ -38,7 +38,7 @@
...
@@ -38,7 +38,7 @@
]
]
},
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"__type__"
:
"_debug_no_trainer"
"__type__"
:
"_debug_no_trainer"
}
}
}
}
test/ut/retiarii/mnist_pytorch.json
View file @
cae4308f
...
@@ -38,7 +38,7 @@
...
@@ -38,7 +38,7 @@
]
]
},
},
"_
training_config
"
:
{
"_
evaluator
"
:
{
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"module"
:
"nni.retiarii.trainer.PyTorchImageClassificationTrainer"
,
"kwargs"
:
{
"kwargs"
:
{
"dataset_cls"
:
"MNIST"
,
"dataset_cls"
:
"MNIST"
,
...
...
test/ut/retiarii/test_cgo_engine.py
View file @
cae4308f
...
@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
...
@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.
traine
r.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.
evaluato
r.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_convert.py
View file @
cae4308f
...
@@ -12,10 +12,9 @@ import torch.nn.functional as F
...
@@ -12,10 +12,9 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
MnistNet
(
nn
.
Module
):
class
MnistNet
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -35,8 +34,8 @@ class MnistNet(nn.Module):
...
@@ -35,8 +34,8 @@ class MnistNet(nn.Module):
x
=
self
.
fc2
(
x
)
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
# NOTE:
blackbox
module cannot be placed within class or function
# NOTE:
serialize
module cannot be placed within class or function
@
b
lackbox_module
@
b
asic_unit
class
Linear
(
nn
.
Module
):
class
Linear
(
nn
.
Module
):
def
__init__
(
self
,
d_embed
,
d_proj
):
def
__init__
(
self
,
d_embed
,
d_proj
):
super
().
__init__
()
super
().
__init__
()
...
@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase):
...
@@ -66,9 +65,6 @@ class TestConvert(unittest.TestCase):
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
...
@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase):
...
@@ -458,9 +454,12 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
def
test_torchvision_resnet18
(
self
):
def
test_torchvision_resnet18
(
self
):
from
.inject_nn
import
inject_pytorch_nn
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
finally
:
remove_inject_pytorch_nn
()
def
test_resnet
(
self
):
def
test_resnet
(
self
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
...
@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase):
...
@@ -572,8 +571,11 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
resnet18
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
self
.
checkExportImport
(
resnet18
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
def
test_alexnet
(
self
):
def
test_alexnet
(
self
):
from
.inject_nn
import
inject_pytorch_nn
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
self
.
checkExportImport
(
model
,
(
x
,))
finally
:
remove_inject_pytorch_nn
()
test/ut/retiarii/test_convert_basic.py
View file @
cae4308f
...
@@ -8,10 +8,9 @@ import torch.nn.functional as F
...
@@ -8,10 +8,9 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
# following pytorch v1.7.1
...
...
test/ut/retiarii/test_convert_operators.py
View file @
cae4308f
...
@@ -15,10 +15,8 @@ import torch.nn.functional as F
...
@@ -15,10 +15,8 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
# following pytorch v1.7.1
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
cae4308f
...
@@ -14,10 +14,9 @@ import torch.nn.functional as F
...
@@ -14,10 +14,9 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_modul
e
from
nni.retiarii
import
serializ
e
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
class
TestPytorch
(
unittest
.
TestCase
):
class
TestPytorch
(
unittest
.
TestCase
):
...
...
test/ut/retiarii/test_dedup_input.py
View file @
cae4308f
...
@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node
...
@@ -17,7 +17,6 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
from
nni.retiarii.utils
import
import_
...
@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase):
...
@@ -74,7 +73,7 @@ class DedupInputTest(unittest.TestCase):
# sys.path.insert(0, 'generated')
# sys.path.insert(0, 'generated')
# multi_model = import_('debug_dedup_input.logical_0')
# multi_model = import_('debug_dedup_input.logical_0')
# trainer = PyTorchMultiModelTrainer(
# trainer = PyTorchMultiModelTrainer(
# multi_model(), phy_models[0][0].
training_config
.kwargs
# multi_model(), phy_models[0][0].
evaluator
.kwargs
# )
# )
# trainer.fit()
# trainer.fit()
...
...
test/ut/retiarii/test_engine.py
View file @
cae4308f
...
@@ -9,7 +9,7 @@ import nni
...
@@ -9,7 +9,7 @@ import nni
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
,
register_advisor
from
nni.retiarii.integration
import
RetiariiAdvisor
,
register_advisor
from
nni.retiarii.
traine
r.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.
evaluato
r.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.utils
import
import_
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_graph.py
View file @
cae4308f
...
@@ -23,7 +23,7 @@ def _test_file(json_path):
...
@@ -23,7 +23,7 @@ def _test_file(json_path):
# add default values to JSON, so we can compare with `==`
# add default values to JSON, so we can compare with `==`
for
graph_name
,
graph
in
orig_ir
.
items
():
for
graph_name
,
graph
in
orig_ir
.
items
():
if
graph_name
==
'_
training_config
'
:
if
graph_name
==
'_
evaluator
'
:
continue
continue
if
'inputs'
not
in
graph
:
if
'inputs'
not
in
graph
:
graph
[
'inputs'
]
=
None
graph
[
'inputs'
]
=
None
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
cae4308f
...
@@ -4,7 +4,7 @@ import unittest
...
@@ -4,7 +4,7 @@ import unittest
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
nni.retiarii
import
Sampler
,
b
lackbox_module
from
nni.retiarii
import
Sampler
,
b
asic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
...
@@ -29,7 +29,7 @@ class RandomSampler(Sampler):
...
@@ -29,7 +29,7 @@ class RandomSampler(Sampler):
return
random
.
choice
(
candidates
)
return
random
.
choice
(
candidates
)
@
b
lackbox_module
@
b
asic_unit
class
MutableConv
(
nn
.
Module
):
class
MutableConv
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
cae4308f
...
@@ -2,13 +2,13 @@ import json
...
@@ -2,13 +2,13 @@ import json
import
pytest
import
pytest
import
nni
import
nni
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
pytorch_lightning
import
pytorch_lightning
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii
import
serialize_cls
,
serialize
from
nni.retiarii.
traine
r
import
Functional
Traine
r
from
nni.retiarii.
evaluato
r
import
Functional
Evaluato
r
from
sklearn.datasets
import
load_diabetes
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
torchvision
import
transforms
from
torchvision
import
transforms
...
@@ -49,7 +49,7 @@ class FCNet(nn.Module):
...
@@ -49,7 +49,7 @@ class FCNet(nn.Module):
return
output
.
view
(
-
1
)
return
output
.
view
(
-
1
)
@
bm
@
serialize_cls
class
DiabetesDataset
(
Dataset
):
class
DiabetesDataset
(
Dataset
):
def
__init__
(
self
,
train
=
True
):
def
__init__
(
self
,
train
=
True
):
data
=
load_diabetes
()
data
=
load_diabetes
()
...
@@ -91,8 +91,8 @@ def _reset():
...
@@ -91,8 +91,8 @@ def _reset():
def
test_mnist
():
def
test_mnist
():
_reset
()
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
...
@@ -121,7 +121,7 @@ def test_diabetes():
...
@@ -121,7 +121,7 @@ def test_diabetes():
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_functional
():
def
test_functional
():
Functional
Traine
r
(
_foo
).
_execute
(
MNISTModel
)
Functional
Evaluato
r
(
_foo
).
_execute
(
MNISTModel
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
test/ut/retiarii/test_serializer.py
View file @
cae4308f
...
@@ -4,7 +4,7 @@ import re
...
@@ -4,7 +4,7 @@ import re
import
sys
import
sys
import
torch
import
torch
from
nni.retiarii
import
json_dumps
,
json_loads
,
blackbox
from
nni.retiarii
import
json_dumps
,
json_loads
,
serialize
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
from
torchvision.datasets
import
MNIST
...
@@ -23,30 +23,30 @@ class Foo:
...
@@ -23,30 +23,30 @@ class Foo:
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_
blackbox
():
def
test_
serialize
():
module
=
blackbox
(
Foo
,
3
)
module
=
serialize
(
Foo
,
3
)
assert
json_loads
(
json_dumps
(
module
))
==
module
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
b
=
2
,
a
=
1
)
module
=
serialize
(
Foo
,
b
=
2
,
a
=
1
)
assert
json_loads
(
json_dumps
(
module
))
==
module
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
Foo
(
1
),
5
)
module
=
serialize
(
Foo
,
Foo
(
1
),
5
)
dumped_module
=
json_dumps
(
module
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
blackbox
(
Foo
,
blackbox
(
Foo
,
1
),
5
)
module
=
serialize
(
Foo
,
serialize
(
Foo
,
1
),
5
)
dumped_module
=
json_dumps
(
module
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
json_loads
(
dumped_module
)
==
module
assert
json_loads
(
dumped_module
)
==
module
def
test_b
lackbox_module
():
def
test_b
asic_unit
():
module
=
ImportTest
(
3
,
0.5
)
module
=
ImportTest
(
3
,
0.5
)
assert
json_loads
(
json_dumps
(
module
))
==
module
assert
json_loads
(
json_dumps
(
module
))
==
module
def
test_dataset
():
def
test_dataset
():
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
dumped_ans
=
{
dumped_ans
=
{
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
...
@@ -62,19 +62,19 @@ def test_dataset():
...
@@ -62,19 +62,19 @@ def test_dataset():
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
blackbox
(
transform
=
serialize
(
transforms
.
Compose
,
transforms
.
Compose
,
[
blackbox
(
transforms
.
ToTensor
),
blackbox
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
[
serialize
(
transforms
.
ToTensor
),
serialize
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
))
))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
...
@@ -87,7 +87,7 @@ def test_type():
...
@@ -87,7 +87,7 @@ def test_type():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_
blackbox
()
test_
serialize
()
test_b
lackbox_module
()
test_b
asic_unit
()
test_dataset
()
test_dataset
()
test_type
()
test_type
()
test/ut/retiarii/test_strategy.py
View file @
cae4308f
...
@@ -12,7 +12,7 @@ from nni.retiarii import Model
...
@@ -12,7 +12,7 @@ from nni.retiarii import Model
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.execution
import
wait_models
from
nni.retiarii.execution
import
wait_models
from
nni.retiarii.execution.interface
import
AbstractExecutionEngine
,
WorkerInfo
,
MetricData
,
AbstractGraphListener
from
nni.retiarii.execution.interface
import
AbstractExecutionEngine
,
WorkerInfo
,
MetricData
,
AbstractGraphListener
from
nni.retiarii.graph
import
Debug
Training
,
ModelStatus
from
nni.retiarii.graph
import
Debug
Evaluator
,
ModelStatus
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
...
@@ -80,7 +80,7 @@ def _get_model_and_mutators():
...
@@ -80,7 +80,7 @@ def _get_model_and_mutators():
base_model
=
Net
()
base_model
=
Net
()
script_module
=
torch
.
jit
.
script
(
base_model
)
script_module
=
torch
.
jit
.
script
(
base_model
)
base_model_ir
=
convert_to_graph
(
script_module
,
base_model
)
base_model_ir
=
convert_to_graph
(
script_module
,
base_model
)
base_model_ir
.
training_config
=
DebugTraining
()
base_model_ir
.
evaluator
=
DebugEvaluator
()
mutators
=
process_inline_mutation
(
base_model_ir
)
mutators
=
process_inline_mutation
(
base_model_ir
)
return
base_model_ir
,
mutators
return
base_model_ir
,
mutators
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment