Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
445e7e0b
Unverified
Commit
445e7e0b
authored
Feb 14, 2021
by
Yuge Zhang
Committed by
GitHub
Feb 14, 2021
Browse files
[Retiarii] Rewrite trainer with PyTorch Lightning (#3359)
parent
137830df
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
279 additions
and
67 deletions
+279
-67
test/retiarii_test/darts/test_oneshot.py
test/retiarii_test/darts/test_oneshot.py
+1
-1
test/retiarii_test/mnasnet/mutator.py
test/retiarii_test/mnasnet/mutator.py
+2
-4
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+27
-12
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+12
-7
test/retiarii_test/simple_strategy.py
test/retiarii_test/simple_strategy.py
+0
-41
test/ut/retiarii/imported/model.py
test/ut/retiarii/imported/model.py
+13
-0
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+1
-2
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+130
-0
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+93
-0
No files found.
test/retiarii_test/darts/test_oneshot.py
View file @
445e7e0b
...
...
@@ -8,7 +8,7 @@ from pathlib import Path
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment
.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.trainer.pytorch
import
DartsTrainer
...
...
test/retiarii_test/mnasnet/mutator.py
View file @
445e7e0b
...
...
@@ -27,14 +27,12 @@ class BlockMutator(Mutator):
n_filter
=
self
.
choice
(
related_info
[
'n_filter_options'
])
if
related_info
[
'in_ch'
]
is
not
None
:
_logger
.
info
(
'zql debug X ...'
)
in_ch
=
related_info
[
'in_ch'
]
else
:
assert
len
(
node
.
predecessors
)
==
1
the_node
=
node
.
predecessors
[
0
]
_logger
.
info
(
'zql debug ...'
)
_logger
.
info
(
the_node
.
operation
.
parameters
)
_logger
.
info
(
the_node
.
__repr__
())
_logger
.
debug
(
repr
(
the_node
.
operation
.
parameters
))
_logger
.
debug
(
the_node
.
__repr__
())
in_ch
=
the_node
.
operation
.
parameters
[
'out_ch'
]
# update the placeholder to be a new operation
...
...
test/retiarii_test/mnasnet/test.py
View file @
445e7e0b
...
...
@@ -5,10 +5,14 @@ from pathlib import Path
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
import
nni.retiarii.trainer.pytorch.lightning
as
pl
from
nni.retiarii
import
blackbox_module
as
bm
from
base_mnasnet
import
MNASNet
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
mutator
import
BlockMutator
if
__name__
==
'__main__'
:
...
...
@@ -20,16 +24,27 @@ if __name__ == '__main__':
base_model
=
MNASNet
(
0.5
,
_DEFAULT_DEPTHS
,
_DEFAULT_CONVOPS
,
_DEFAULT_KERNEL_SIZES
,
_DEFAULT_NUM_LAYERS
,
_DEFAULT_SKIPS
)
trainer
=
PyTorchImageClassificationTrainer
(
base_model
,
dataset_cls
=
"CIFAR10"
,
dataset_kwargs
=
{
"root"
:
"data/cifar10"
,
"download"
:
True
},
dataloader_kwargs
=
{
"batch_size"
:
32
},
optimizer_kwargs
=
{
"lr"
:
1e-3
},
trainer_kwargs
=
{
"max_epochs"
:
1
})
# new interface
applied_mutators
=
[]
applied_mutators
.
append
(
BlockMutator
(
'mutable_0'
))
applied_mutators
.
append
(
BlockMutator
(
'mutable_1'
))
train_transform
=
transforms
.
Compose
([
transforms
.
RandomCrop
(
32
,
padding
=
4
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
valid_transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
train_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.2
)
applied_mutators
=
[
BlockMutator
(
'mutable_0'
),
BlockMutator
(
'mutable_1'
)
]
simple_startegy
=
TPEStrategy
()
...
...
test/retiarii_test/mnist/test.py
View file @
445e7e0b
import
random
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.trainer.pytorch.lightning
as
pl
import
torch.nn.functional
as
F
from
nni.retiarii.experiment
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii.strategies
import
RandomStrategy
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
class
Net
(
nn
.
Module
):
...
...
@@ -31,11 +35,12 @@ class Net(nn.Module):
if
__name__
==
'__main__'
:
base_model
=
Net
(
128
)
trainer
=
PyTorchImageClassificationTrainer
(
base_model
,
dataset_cls
=
"MNIST"
,
dataset_kwargs
=
{
"root"
:
"data/mnist"
,
"download"
:
True
},
dataloader_kwargs
=
{
"batch_size"
:
32
},
optimizer_kwargs
=
{
"lr"
:
1e-3
},
trainer_kwargs
=
{
"max_epochs"
:
1
})
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
)
simple_startegy
=
RandomStrategy
()
...
...
test/retiarii_test/simple_strategy.py
deleted
100644 → 0
View file @
137830df
import
json
import
logging
import
random
import
os
from
nni.retiarii
import
Model
,
submit_models
,
wait_models
from
nni.retiarii.strategy
import
BaseStrategy
from
nni.retiarii
import
Sampler
_logger
=
logging
.
getLogger
(
__name__
)
class
RandomSampler
(
Sampler
):
def
choice
(
self
,
candidates
,
mutator
,
model
,
index
):
return
random
.
choice
(
candidates
)
class
SimpleStrategy
(
BaseStrategy
):
def
__init__
(
self
):
self
.
name
=
''
def
run
(
self
,
base_model
,
applied_mutators
,
trainer
):
try
:
_logger
.
info
(
'stargety start...'
)
while
True
:
model
=
base_model
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'mutators: {}'
.
format
(
applied_mutators
))
random_sampler
=
RandomSampler
()
for
mutator
in
applied_mutators
:
_logger
.
info
(
'mutate model...'
)
mutator
.
bind_sampler
(
random_sampler
)
model
=
mutator
.
apply
(
model
)
# get and apply training approach
_logger
.
info
(
'apply training approach...'
)
model
.
apply_trainer
(
trainer
[
'modulename'
],
trainer
[
'args'
])
# run models
submit_models
(
model
)
wait_models
(
model
)
_logger
.
info
(
'Strategy says:'
,
model
.
metric
)
except
Exception
as
e
:
_logger
.
error
(
logging
.
exception
(
'message'
))
test/ut/retiarii/imported/model.py
0 → 100644
View file @
445e7e0b
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
@
blackbox_module
class
ImportTest
(
nn
.
Module
):
def
__init__
(
self
,
foo
,
bar
):
super
().
__init__
()
self
.
foo
=
nn
.
Linear
(
foo
,
3
)
self
.
bar
=
nn
.
Dropout
(
bar
)
def
__eq__
(
self
,
other
):
return
self
.
foo
.
in_features
==
other
.
foo
.
in_features
and
self
.
bar
.
p
==
other
.
bar
.
p
test/ut/retiarii/mnist-tensorflow.json
View file @
445e7e0b
...
...
@@ -39,7 +39,6 @@
},
"_training_config"
:
{
"module"
:
"_debug_no_trainer"
,
"kwargs"
:
{}
"__type__"
:
"_debug_no_trainer"
}
}
test/ut/retiarii/test_lightning_trainer.py
0 → 100644
View file @
445e7e0b
import
json
import
pytest
import
nni
import
nni.retiarii.trainer.pytorch.lightning
as
pl
import
pytorch_lightning
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii.trainer
import
FunctionalTrainer
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
debug
=
False
progress_bar_refresh_rate
=
0
if
debug
:
progress_bar_refresh_rate
=
1
class
MNISTModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layer_1
=
nn
.
Linear
(
28
*
28
,
128
)
self
.
layer_2
=
nn
.
Linear
(
128
,
10
)
def
forward
(
self
,
x
):
x
=
x
.
view
(
x
.
size
(
0
),
-
1
)
x
=
self
.
layer_1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
layer_2
(
x
)
return
x
class
FCNet
(
nn
.
Module
):
def
__init__
(
self
,
input_size
,
output_size
):
super
().
__init__
()
self
.
l1
=
nn
.
Linear
(
input_size
,
5
)
self
.
relu
=
nn
.
ReLU
()
self
.
l2
=
nn
.
Linear
(
5
,
output_size
)
def
forward
(
self
,
x
):
output
=
self
.
l1
(
x
)
output
=
self
.
relu
(
output
)
output
=
self
.
l2
(
output
)
return
output
.
view
(
-
1
)
@
bm
class
DiabetesDataset
(
Dataset
):
def
__init__
(
self
,
train
=
True
):
data
=
load_diabetes
()
self
.
x
=
torch
.
tensor
(
data
[
'data'
],
dtype
=
torch
.
float32
)
self
.
y
=
torch
.
tensor
(
data
[
'target'
],
dtype
=
torch
.
float32
)
self
.
length
=
self
.
x
.
shape
[
0
]
split
=
int
(
self
.
length
*
0.8
)
if
train
:
self
.
x
=
self
.
x
[:
split
]
self
.
y
=
self
.
y
[:
split
]
else
:
self
.
x
=
self
.
x
[
split
:]
self
.
y
=
self
.
y
[
split
:]
self
.
length
=
len
(
self
.
y
)
def
__getitem__
(
self
,
idx
):
return
self
.
x
[
idx
],
self
.
y
[
idx
]
def
__len__
(
self
):
return
self
.
length
def
_get_final_result
():
return
float
(
json
.
loads
(
nni
.
runtime
.
platform
.
test
.
_last_metric
)[
'value'
])
def
_foo
(
model_cls
):
assert
model_cls
==
MNISTModel
def
_reset
():
# this is to not affect other tests in sdk
nni
.
trial
.
_intermediate_seq
=
0
nni
.
trial
.
_params
=
{
'foo'
:
'bar'
,
'parameter_id'
:
0
}
nni
.
runtime
.
platform
.
test
.
_last_metric
=
None
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_mnist
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
progress_bar_refresh_rate
=
progress_bar_refresh_rate
)
lightning
.
_execute
(
MNISTModel
)
assert
_get_final_result
()
>
0.7
_reset
()
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_diabetes
():
_reset
()
nni
.
trial
.
_params
=
{
'foo'
:
'bar'
,
'parameter_id'
:
0
}
nni
.
runtime
.
platform
.
test
.
_last_metric
=
None
train_dataset
=
DiabetesDataset
(
train
=
True
)
test_dataset
=
DiabetesDataset
(
train
=
False
)
lightning
=
pl
.
Regression
(
optimizer
=
torch
.
optim
.
SGD
,
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
20
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
20
),
max_epochs
=
100
,
progress_bar_refresh_rate
=
progress_bar_refresh_rate
)
lightning
.
_execute
(
FCNet
(
train_dataset
.
x
.
shape
[
1
],
1
))
assert
_get_final_result
()
<
2e4
_reset
()
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_functional
():
FunctionalTrainer
(
_foo
).
_execute
(
MNISTModel
)
if
__name__
==
'__main__'
:
test_mnist
()
test_diabetes
()
test_functional
()
test/ut/retiarii/test_serializer.py
0 → 100644
View file @
445e7e0b
import
json
from
pathlib
import
Path
import
re
import
sys
import
torch
from
nni.retiarii
import
json_dumps
,
json_loads
,
blackbox
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
sys
.
path
.
insert
(
0
,
Path
(
__file__
).
parent
.
as_posix
())
from
imported.model
import
ImportTest
class
Foo
:
def
__init__
(
self
,
a
,
b
=
1
):
self
.
aa
=
a
self
.
bb
=
[
b
+
1
for
_
in
range
(
1000
)]
def
__eq__
(
self
,
other
):
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_blackbox
():
module
=
blackbox
(
Foo
,
3
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
b
=
2
,
a
=
1
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
Foo
(
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
blackbox
(
Foo
,
blackbox
(
Foo
,
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
json_loads
(
dumped_module
)
==
module
def
test_blackbox_module
():
module
=
ImportTest
(
3
,
0.5
)
assert
json_loads
(
json_dumps
(
module
))
==
module
def
test_dataset
():
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dumped_ans
=
{
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
"arguments"
:
{
"batch_size"
:
10
,
"dataset"
:
{
"__type__"
:
"torchvision.datasets.mnist.MNIST"
,
"arguments"
:
{
"root"
:
"data/mnist"
,
"train"
:
False
,
"download"
:
True
}
}
}
}
assert
json_dumps
(
dataloader
)
==
json_dumps
(
dumped_ans
)
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
blackbox
(
transforms
.
Compose
,
[
blackbox
(
transforms
.
ToTensor
),
blackbox
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def
test_type
():
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
if
__name__
==
'__main__'
:
test_blackbox
()
test_blackbox_module
()
test_dataset
()
test_type
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment