Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
445e7e0b
"tests/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "6906f72a805b2ba3057af462bc2c5214f2d87cd5"
Unverified
Commit
445e7e0b
authored
Feb 14, 2021
by
Yuge Zhang
Committed by
GitHub
Feb 14, 2021
Browse files
[Retiarii] Rewrite trainer with PyTorch Lightning (#3359)
parent
137830df
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
279 additions
and
67 deletions
+279
-67
test/retiarii_test/darts/test_oneshot.py
test/retiarii_test/darts/test_oneshot.py
+1
-1
test/retiarii_test/mnasnet/mutator.py
test/retiarii_test/mnasnet/mutator.py
+2
-4
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+27
-12
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+12
-7
test/retiarii_test/simple_strategy.py
test/retiarii_test/simple_strategy.py
+0
-41
test/ut/retiarii/imported/model.py
test/ut/retiarii/imported/model.py
+13
-0
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+1
-2
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+130
-0
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+93
-0
No files found.
test/retiarii_test/darts/test_oneshot.py
View file @
445e7e0b
...
@@ -8,7 +8,7 @@ from pathlib import Path
...
@@ -8,7 +8,7 @@ from pathlib import Path
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
torchvision.datasets
import
CIFAR10
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment
.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.trainer.pytorch
import
DartsTrainer
from
nni.retiarii.trainer.pytorch
import
DartsTrainer
...
...
test/retiarii_test/mnasnet/mutator.py
View file @
445e7e0b
...
@@ -27,14 +27,12 @@ class BlockMutator(Mutator):
...
@@ -27,14 +27,12 @@ class BlockMutator(Mutator):
n_filter
=
self
.
choice
(
related_info
[
'n_filter_options'
])
n_filter
=
self
.
choice
(
related_info
[
'n_filter_options'
])
if
related_info
[
'in_ch'
]
is
not
None
:
if
related_info
[
'in_ch'
]
is
not
None
:
_logger
.
info
(
'zql debug X ...'
)
in_ch
=
related_info
[
'in_ch'
]
in_ch
=
related_info
[
'in_ch'
]
else
:
else
:
assert
len
(
node
.
predecessors
)
==
1
assert
len
(
node
.
predecessors
)
==
1
the_node
=
node
.
predecessors
[
0
]
the_node
=
node
.
predecessors
[
0
]
_logger
.
info
(
'zql debug ...'
)
_logger
.
debug
(
repr
(
the_node
.
operation
.
parameters
))
_logger
.
info
(
the_node
.
operation
.
parameters
)
_logger
.
debug
(
the_node
.
__repr__
())
_logger
.
info
(
the_node
.
__repr__
())
in_ch
=
the_node
.
operation
.
parameters
[
'out_ch'
]
in_ch
=
the_node
.
operation
.
parameters
[
'out_ch'
]
# update the placeholder to be a new operation
# update the placeholder to be a new operation
...
...
test/retiarii_test/mnasnet/test.py
View file @
445e7e0b
...
@@ -5,10 +5,14 @@ from pathlib import Path
...
@@ -5,10 +5,14 @@ from pathlib import Path
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
import
nni.retiarii.trainer.pytorch.lightning
as
pl
from
nni.retiarii
import
blackbox_module
as
bm
from
base_mnasnet
import
MNASNet
from
base_mnasnet
import
MNASNet
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.strategies
import
TPEStrategy
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
mutator
import
BlockMutator
from
mutator
import
BlockMutator
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
@@ -20,16 +24,27 @@ if __name__ == '__main__':
...
@@ -20,16 +24,27 @@ if __name__ == '__main__':
base_model
=
MNASNet
(
0.5
,
_DEFAULT_DEPTHS
,
_DEFAULT_CONVOPS
,
_DEFAULT_KERNEL_SIZES
,
base_model
=
MNASNet
(
0.5
,
_DEFAULT_DEPTHS
,
_DEFAULT_CONVOPS
,
_DEFAULT_KERNEL_SIZES
,
_DEFAULT_NUM_LAYERS
,
_DEFAULT_SKIPS
)
_DEFAULT_NUM_LAYERS
,
_DEFAULT_SKIPS
)
trainer
=
PyTorchImageClassificationTrainer
(
base_model
,
dataset_cls
=
"CIFAR10"
,
dataset_kwargs
=
{
"root"
:
"data/cifar10"
,
"download"
:
True
},
train_transform
=
transforms
.
Compose
([
dataloader_kwargs
=
{
"batch_size"
:
32
},
transforms
.
RandomCrop
(
32
,
padding
=
4
),
optimizer_kwargs
=
{
"lr"
:
1e-3
},
transforms
.
RandomHorizontalFlip
(),
trainer_kwargs
=
{
"max_epochs"
:
1
})
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
# new interface
])
applied_mutators
=
[]
valid_transform
=
transforms
.
Compose
([
applied_mutators
.
append
(
BlockMutator
(
'mutable_0'
))
transforms
.
ToTensor
(),
applied_mutators
.
append
(
BlockMutator
(
'mutable_1'
))
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
train_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.2
)
applied_mutators
=
[
BlockMutator
(
'mutable_0'
),
BlockMutator
(
'mutable_1'
)
]
simple_startegy
=
TPEStrategy
()
simple_startegy
=
TPEStrategy
()
...
...
test/retiarii_test/mnist/test.py
View file @
445e7e0b
import
random
import
random
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.trainer.pytorch.lightning
as
pl
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
nni.retiarii.experiment
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii.strategies
import
RandomStrategy
from
nni.retiarii.strategies
import
RandomStrategy
from
nni.retiarii.trainer.pytorch
import
PyTorchImageClassificationTrainer
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
@@ -31,11 +35,12 @@ class Net(nn.Module):
...
@@ -31,11 +35,12 @@ class Net(nn.Module):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
base_model
=
Net
(
128
)
base_model
=
Net
(
128
)
trainer
=
PyTorchImageClassificationTrainer
(
base_model
,
dataset_cls
=
"MNIST"
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
dataset_kwargs
=
{
"root"
:
"data/mnist"
,
"download"
:
True
},
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
dataloader_kwargs
=
{
"batch_size"
:
32
},
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
optimizer_kwargs
=
{
"lr"
:
1e-3
},
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
trainer_kwargs
=
{
"max_epochs"
:
1
})
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
)
simple_startegy
=
RandomStrategy
()
simple_startegy
=
RandomStrategy
()
...
...
test/retiarii_test/simple_strategy.py
deleted
100644 → 0
View file @
137830df
import
json
import
logging
import
random
import
os
from
nni.retiarii
import
Model
,
submit_models
,
wait_models
from
nni.retiarii.strategy
import
BaseStrategy
from
nni.retiarii
import
Sampler
_logger
=
logging
.
getLogger
(
__name__
)
class
RandomSampler
(
Sampler
):
def
choice
(
self
,
candidates
,
mutator
,
model
,
index
):
return
random
.
choice
(
candidates
)
class
SimpleStrategy
(
BaseStrategy
):
def
__init__
(
self
):
self
.
name
=
''
def
run
(
self
,
base_model
,
applied_mutators
,
trainer
):
try
:
_logger
.
info
(
'stargety start...'
)
while
True
:
model
=
base_model
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'mutators: {}'
.
format
(
applied_mutators
))
random_sampler
=
RandomSampler
()
for
mutator
in
applied_mutators
:
_logger
.
info
(
'mutate model...'
)
mutator
.
bind_sampler
(
random_sampler
)
model
=
mutator
.
apply
(
model
)
# get and apply training approach
_logger
.
info
(
'apply training approach...'
)
model
.
apply_trainer
(
trainer
[
'modulename'
],
trainer
[
'args'
])
# run models
submit_models
(
model
)
wait_models
(
model
)
_logger
.
info
(
'Strategy says:'
,
model
.
metric
)
except
Exception
as
e
:
_logger
.
error
(
logging
.
exception
(
'message'
))
test/ut/retiarii/imported/model.py
0 → 100644
View file @
445e7e0b
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
@
blackbox_module
class
ImportTest
(
nn
.
Module
):
def
__init__
(
self
,
foo
,
bar
):
super
().
__init__
()
self
.
foo
=
nn
.
Linear
(
foo
,
3
)
self
.
bar
=
nn
.
Dropout
(
bar
)
def
__eq__
(
self
,
other
):
return
self
.
foo
.
in_features
==
other
.
foo
.
in_features
and
self
.
bar
.
p
==
other
.
bar
.
p
test/ut/retiarii/mnist-tensorflow.json
View file @
445e7e0b
...
@@ -39,7 +39,6 @@
...
@@ -39,7 +39,6 @@
},
},
"_training_config"
:
{
"_training_config"
:
{
"module"
:
"_debug_no_trainer"
,
"__type__"
:
"_debug_no_trainer"
"kwargs"
:
{}
}
}
}
}
test/ut/retiarii/test_lightning_trainer.py
0 → 100644
View file @
445e7e0b
import
json
import
pytest
import
nni
import
nni.retiarii.trainer.pytorch.lightning
as
pl
import
pytorch_lightning
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii.trainer
import
FunctionalTrainer
from
sklearn.datasets
import
load_diabetes
from
torch.utils.data
import
Dataset
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
debug
=
False
progress_bar_refresh_rate
=
0
if
debug
:
progress_bar_refresh_rate
=
1
class
MNISTModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
layer_1
=
nn
.
Linear
(
28
*
28
,
128
)
self
.
layer_2
=
nn
.
Linear
(
128
,
10
)
def
forward
(
self
,
x
):
x
=
x
.
view
(
x
.
size
(
0
),
-
1
)
x
=
self
.
layer_1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
layer_2
(
x
)
return
x
class
FCNet
(
nn
.
Module
):
def
__init__
(
self
,
input_size
,
output_size
):
super
().
__init__
()
self
.
l1
=
nn
.
Linear
(
input_size
,
5
)
self
.
relu
=
nn
.
ReLU
()
self
.
l2
=
nn
.
Linear
(
5
,
output_size
)
def
forward
(
self
,
x
):
output
=
self
.
l1
(
x
)
output
=
self
.
relu
(
output
)
output
=
self
.
l2
(
output
)
return
output
.
view
(
-
1
)
@
bm
class
DiabetesDataset
(
Dataset
):
def
__init__
(
self
,
train
=
True
):
data
=
load_diabetes
()
self
.
x
=
torch
.
tensor
(
data
[
'data'
],
dtype
=
torch
.
float32
)
self
.
y
=
torch
.
tensor
(
data
[
'target'
],
dtype
=
torch
.
float32
)
self
.
length
=
self
.
x
.
shape
[
0
]
split
=
int
(
self
.
length
*
0.8
)
if
train
:
self
.
x
=
self
.
x
[:
split
]
self
.
y
=
self
.
y
[:
split
]
else
:
self
.
x
=
self
.
x
[
split
:]
self
.
y
=
self
.
y
[
split
:]
self
.
length
=
len
(
self
.
y
)
def
__getitem__
(
self
,
idx
):
return
self
.
x
[
idx
],
self
.
y
[
idx
]
def
__len__
(
self
):
return
self
.
length
def
_get_final_result
():
return
float
(
json
.
loads
(
nni
.
runtime
.
platform
.
test
.
_last_metric
)[
'value'
])
def
_foo
(
model_cls
):
assert
model_cls
==
MNISTModel
def
_reset
():
# this is to not affect other tests in sdk
nni
.
trial
.
_intermediate_seq
=
0
nni
.
trial
.
_params
=
{
'foo'
:
'bar'
,
'parameter_id'
:
0
}
nni
.
runtime
.
platform
.
test
.
_last_metric
=
None
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_mnist
():
_reset
()
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
bm
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
progress_bar_refresh_rate
=
progress_bar_refresh_rate
)
lightning
.
_execute
(
MNISTModel
)
assert
_get_final_result
()
>
0.7
_reset
()
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_diabetes
():
_reset
()
nni
.
trial
.
_params
=
{
'foo'
:
'bar'
,
'parameter_id'
:
0
}
nni
.
runtime
.
platform
.
test
.
_last_metric
=
None
train_dataset
=
DiabetesDataset
(
train
=
True
)
test_dataset
=
DiabetesDataset
(
train
=
False
)
lightning
=
pl
.
Regression
(
optimizer
=
torch
.
optim
.
SGD
,
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
20
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
20
),
max_epochs
=
100
,
progress_bar_refresh_rate
=
progress_bar_refresh_rate
)
lightning
.
_execute
(
FCNet
(
train_dataset
.
x
.
shape
[
1
],
1
))
assert
_get_final_result
()
<
2e4
_reset
()
@
pytest
.
mark
.
skipif
(
pytorch_lightning
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs.'
)
def
test_functional
():
FunctionalTrainer
(
_foo
).
_execute
(
MNISTModel
)
if
__name__
==
'__main__'
:
test_mnist
()
test_diabetes
()
test_functional
()
test/ut/retiarii/test_serializer.py
0 → 100644
View file @
445e7e0b
import
json
from
pathlib
import
Path
import
re
import
sys
import
torch
from
nni.retiarii
import
json_dumps
,
json_loads
,
blackbox
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
sys
.
path
.
insert
(
0
,
Path
(
__file__
).
parent
.
as_posix
())
from
imported.model
import
ImportTest
class
Foo
:
def
__init__
(
self
,
a
,
b
=
1
):
self
.
aa
=
a
self
.
bb
=
[
b
+
1
for
_
in
range
(
1000
)]
def
__eq__
(
self
,
other
):
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_blackbox
():
module
=
blackbox
(
Foo
,
3
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
b
=
2
,
a
=
1
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
blackbox
(
Foo
,
Foo
(
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
blackbox
(
Foo
,
blackbox
(
Foo
,
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
json_loads
(
dumped_module
)
==
module
def
test_blackbox_module
():
module
=
ImportTest
(
3
,
0.5
)
assert
json_loads
(
json_dumps
(
module
))
==
module
def
test_dataset
():
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
dumped_ans
=
{
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
"arguments"
:
{
"batch_size"
:
10
,
"dataset"
:
{
"__type__"
:
"torchvision.datasets.mnist.MNIST"
,
"arguments"
:
{
"root"
:
"data/mnist"
,
"train"
:
False
,
"download"
:
True
}
}
}
}
assert
json_dumps
(
dataloader
)
==
json_dumps
(
dumped_ans
)
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
blackbox
(
transforms
.
Compose
,
[
blackbox
(
transforms
.
ToTensor
),
blackbox
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
blackbox
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
dataloader
=
blackbox
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def
test_type
():
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
if
__name__
==
'__main__'
:
test_blackbox
()
test_blackbox_module
()
test_dataset
()
test_type
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment