OpenDAS / nni · Commit 39ec21ca

Multi-GPU support of one-shot NAS (#4603)

Unverified commit, authored May 20, 2022 by Frandium; committed by GitHub on May 20, 2022.
Parent: b4559f60
Changes: 10 changed files with 382 additions and 94 deletions (+382 -94).
nni/retiarii/evaluator/pytorch/lightning.py        +60  -24
nni/retiarii/oneshot/pytorch/base_lightning.py     +10  -8
nni/retiarii/oneshot/pytorch/dataloader.py         +77  -0
nni/retiarii/oneshot/pytorch/differentiable.py     +3   -2
nni/retiarii/oneshot/pytorch/sampling.py           +19  -14
nni/retiarii/oneshot/pytorch/strategy.py           +38  -28
nni/retiarii/oneshot/pytorch/utils.py              +1   -0
nni/retiarii/strategy/tpe_strategy.py              +3   -2
test/ut/retiarii/test_oneshot.py                   +40  -16
test/ut/retiarii/test_oneshot_utils.py             +131 -0
nni/retiarii/evaluator/pytorch/lightning.py

@@ -4,7 +4,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable, Type
+from typing import Any, Dict, Union, Optional, List, Callable, Type

 import pytorch_lightning as pl
 import torch.nn as nn

@@ -22,6 +22,7 @@ except ImportError:
     cgo_import_failed = True

 from nni.retiarii.graph import Evaluator
+from nni.typehint import Literal

 __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']

@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
     See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
     """

+    running_mode: Literal['multi', 'oneshot'] = 'multi'
+    """An indicator of whether current module is running in a multi-trial experiment or an one-shot.
+    This flag should be automatically set by experiments when they start to run.
+    """
+
     def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
         """Set the inner model (architecture) to train / evaluate.

@@ -59,6 +65,7 @@ DataLoader.__doc__ = """
 Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
 """

+
 @nni.trace
 class Lightning(Evaluator):
     """

@@ -74,51 +81,67 @@ class Lightning(Evaluator):
     Parameters
     ----------
-    lightning_module : LightningModule
+    lightning_module
         Lightning module that defines the training logic.
-    trainer : Trainer
+    trainer
         Lightning trainer that handles the training.
-    train_dataloders : DataLoader
+    train_dataloders
         Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
         If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
-    val_dataloaders : DataLoader or List of DataLoader
+        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
+    val_dataloaders
         Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
         If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
     """

     def __init__(self, lightning_module: LightningModule, trainer: Trainer,
-                 train_dataloader: Optional[DataLoader] = None,
-                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
+                 train_dataloaders: Optional[Any] = None,
+                 val_dataloaders: Optional[Any] = None,
+                 train_dataloader: Optional[Any] = None):
         assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         if cgo_import_failed:
             assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
         else:
             # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
             assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
                 f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
-        assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
-        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
+        if not _check_dataloader(train_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {train_dataloaders}',
+                          RuntimeWarning)
+        if not _check_dataloader(val_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {val_dataloaders}',
+                          RuntimeWarning)
         self.module = lightning_module
         self.trainer = trainer
-        self.train_dataloader = train_dataloader
+        self.train_dataloaders = train_dataloaders
         self.val_dataloaders = val_dataloaders

     @staticmethod
     def _load(ir):
-        return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
+        return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])

     def _dump(self):
         return {
             'type': self.__class__,
             'module': self.module,
             'trainer': self.trainer,
-            'train_dataloader': self.train_dataloader,
+            'train_dataloaders': self.train_dataloaders,
             'val_dataloaders': self.val_dataloaders
         }

     def _execute(self, model_cls):
         return self.fit(model_cls)

+    @property
+    def train_dataloader(self):
+        warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
+
     def __eq__(self, other):
         eq_func = False
         eq_args = False

@@ -146,15 +169,18 @@ class Lightning(Evaluator):
         The model to fit.
         """
         self.module.set_model(model)
-        return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
+        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)


 def _check_dataloader(dataloader):
+    if dataloader is None:
+        return True
+    # Check the type of dataloader recursively.
     if isinstance(dataloader, list):
         return all([_check_dataloader(d) for d in dataloader])
-    return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
+    if isinstance(dataloader, dict):
+        return all([_check_dataloader(v) for v in dataloader.values()])
+    if isinstance(dataloader, torch_data.DataLoader):
+        return is_traceable(dataloader)
+    return True


 ### The following are some commonly used Lightning modules ###

@@ -176,7 +202,6 @@ class _SupervisedLearningModule(LightningModule):
         if export_onnx is None or export_onnx is True:
             self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
-            self.export_onnx.parent.mkdir(exist_ok=True)
         elif export_onnx:
             self.export_onnx = Path(export_onnx)
         else:

@@ -199,7 +224,8 @@ class _SupervisedLearningModule(LightningModule):
         x, y = batch
         y_hat = self(x)
-        if self.export_onnx is not None:
+        if self.running_mode == 'multi' and self.export_onnx is not None:
+            self.export_onnx.parent.mkdir(exist_ok=True)
             try:
                 self.to_onnx(self.export_onnx, x, export_params=True)
             except RuntimeError as e:

@@ -221,9 +247,11 @@ class _SupervisedLearningModule(LightningModule):
         return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
-        nni.report_intermediate_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_intermediate_result(self._get_validation_metrics())

     def on_fit_end(self):
-        nni.report_final_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_final_result(self._get_validation_metrics())

     def _get_validation_metrics(self):

@@ -283,14 +311,18 @@ class Classification(Lightning):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                        weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)

 @nni.trace

@@ -336,11 +368,15 @@ class Regression(Lightning):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                    weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
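For orientation, a minimal sketch (not part of the commit) of the renamed argument in use, assuming the traced `DataLoader` and `Classification` re-exported from `nni.retiarii.evaluator.pytorch`; the toy dataset and loader names are illustrative stand-ins.

import torch
from torch.utils.data import TensorDataset
from nni.retiarii.evaluator.pytorch import Classification, DataLoader

# Toy MNIST-shaped dataset, purely illustrative.
dataset = TensorDataset(torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,)))
train_loader = DataLoader(dataset, batch_size=4)  # traced DataLoader, as the new warning asks for
val_loader = DataLoader(dataset, batch_size=4)

# New-style call: the plural keyword introduced by this commit.
evaluator = Classification(train_dataloaders=train_loader, val_dataloaders=val_loader, max_epochs=1)

# Old-style call still works, but now emits a DeprecationWarning and
# forwards the loader to `train_dataloaders`.
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)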
nni/retiarii/oneshot/pytorch/base_lightning.py

@@ -18,6 +18,7 @@ import nni.retiarii.nn.pytorch as nas_nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
 from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.typehint import Literal
 from .supermodule.base import BaseSuperNetModule

 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']

@@ -334,21 +335,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return arc_optimizers + w_optimizers, lr_schedulers

     def on_train_start(self):
-        # redirect the access to trainer/log to this module
-        # but note that we might be missing other attributes,
-        # which could potentially be a problem
-        self.model.trainer = self.trainer  # type: ignore
-        self.model.log = self.log
         return self.model.on_train_start()

     def on_train_end(self):
         return self.model.on_train_end()

     def on_fit_start(self):
-        return self.model.on_train_start()
+        # redirect the access to trainer/log to this module
+        # but note that we might be missing other attributes,
+        # which could potentially be a problem
+        self.model.trainer = self.trainer  # type: ignore
+        self.model.log = self.log
+        return self.model.on_fit_start()

     def on_fit_end(self):
-        return self.model.on_train_end()
+        return self.model.on_fit_end()

     def on_train_batch_start(self, batch, batch_idx, unused=0):
         return self.model.on_train_batch_start(batch, batch_idx, unused)

@@ -356,6 +357,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
     def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
         return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

+    # Deprecated hooks in pytorch-lightning
     def on_epoch_start(self):
         return self.model.on_epoch_start()

@@ -427,7 +429,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         else:
             apply(lr_schedulers)

-    def call_weight_optimizers(self, method):
+    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
         """
         Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
         class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
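The hook reshuffle above is easier to read with Lightning's call order in mind: `on_fit_start` fires once when `trainer.fit()` is entered, before the validation sanity check, whereas `on_train_start` only fires when the training loop begins, so attaching `trainer`/`log` in `on_fit_start` makes them available during sanity checking too. A standalone probe of that ordering, as a sketch (not part of the diff):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class HookProbe(pl.LightningModule):
    """Prints Lightning's hook order: on_fit_start fires before on_train_start."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_fit_start(self):
        print('on_fit_start')    # first: right after fit() is entered

    def on_train_start(self):
        print('on_train_start')  # later: once the training loop actually starts

data = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
pl.Trainer(max_epochs=1).fit(HookProbe(), data)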
nni/retiarii/oneshot/pytorch/dataloader.py (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Any

from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator


class ConcatLoader(CombinedLoader):
    """This loader is same as CombinedLoader in PyTorch-Lightning, but concatenate sub-loaders
    instead of loading them in parallel.

    Parameters
    ----------
    loaders
        For example, ::

            {
                "train": DataLoader(train_dataset),
                "val": DataLoader(val_dataset)
            }

        In this example, the loader will first produce the batches from "train", then "val".
    mode
        Only support "min_size" for now.
    """

    def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
        # FIXME: max_cycle will make dataloaders cycle iterators,
        # causing extra problems.
        if mode != 'min_size':
            raise ValueError('Only min_size mode is supported now.')
        super().__init__(loaders, mode)

    def __iter__(self) -> Any:
        """Replace the super-class iterator with ours."""
        self._try_to_patch_pytorch_dataloader()
        iterator = ConcatLoaderIterator(self.loaders)
        # handle fault tolerant restart.
        self.on_restart(iterator)
        self._iterator = iterator
        return iterator

    @staticmethod
    def _try_to_patch_pytorch_dataloader():
        """Copied from CombinedLoader."""
        from torch.utils.data.dataloader import _BaseDataLoaderIter

        # prevent `NotImplementedError` from PyTorch:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
        def __getstate__patch__(*_):
            return {}

        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore

    def __len__(self) -> int:
        return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))


class ConcatLoaderIterator(CombinedLoaderIterator):
    """Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

    def __next__(self) -> Any:
        """Fetches the next batch from multiple data loaders,
        by looking for the first iterator that isn't exhausted yet.
        """
        if not len(self.loader_iters) == len(self.loaders):
            raise RuntimeError('loader_iters must have the same length as loaders.')
        for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
            try:
                return (self.request_next_batch(iterator), loader_name)
            except StopIteration:
                if i + 1 == len(self.loader_iters):
                    raise
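A short usage sketch of the new loader (mirroring the behavior exercised by the unit tests added below): sub-loaders are drained one after another, and each batch is paired with the key of the loader that produced it, which is exactly what `EnasLightningModule.training_step` unpacks.

from torch.utils.data import DataLoader
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

loader = ConcatLoader({
    'train': DataLoader(range(10), batch_size=4),  # 3 batches
    'val': DataLoader(range(20), batch_size=5),    # 4 batches
})
for batch, source in loader:
    print(source, batch)  # 'train' for the first 3 batches, then 'val'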
nni/retiarii/oneshot/pytorch/differentiable.py

@@ -75,8 +75,9 @@ class DartsLightningModule(BaseOneShotLightningModule):
         if not isinstance(arc_optim, optim.Optimizer):
             raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

-        # The InterleavedTrainValDataLoader yields both train and val data in a batch
-        trn_batch, val_batch = batch
+        # DARTS strategy makes sure that ``train`` and ``val`` must be in the batch
+        trn_batch = batch['train']
+        val_batch = batch['val']

         # phase 1: architecture step
         # The _resample hook is kept for some darts-based NAS methods like proxyless.
nni/retiarii/oneshot/pytorch/sampling.py

@@ -133,29 +133,30 @@ class EnasLightningModule(RandomSamplingLightningModule):
     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)

-    def training_step(self, batch, batch_idx):
-        # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
-        batch, source = batch
+    def training_step(self, batch_packed, batch_idx):
+        batch, mode = batch_packed

-        if source == 'train':
-            # step 1: train model params
+        if mode == 'train':
+            # train model params
             with torch.no_grad():
                 self.resample()
             self.call_weight_optimizers('zero_grad')
-            loss_and_metrics = self.model.training_step(batch, batch_idx)
-            w_step_loss = loss_and_metrics['loss'] \
-                if isinstance(loss_and_metrics, dict) else loss_and_metrics
+            step_output = self.model.training_step(batch, batch_idx)
+            w_step_loss = step_output['loss'] \
+                if isinstance(step_output, dict) else step_output
             self.manual_backward(w_step_loss)
             self.call_weight_optimizers('step')
-            return loss_and_metrics

-        if source == 'val':
-            # step 2: train ENAS agent
+        else:
+            # train ENAS agent
             arc_opt = self.architecture_optimizers()
             if not isinstance(arc_opt, optim.Optimizer):
                 raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
             arc_opt.zero_grad()
             self.resample()
-            self.model.validation_step(batch, batch_idx)
+            step_output = self.model.validation_step(batch, batch_idx)

             # use the default metric of self.model as reward function
             if len(self.trainer.callback_metrics) == 1:
                 _, metric = next(iter(self.trainer.callback_metrics.items()))

@@ -163,7 +164,9 @@ class EnasLightningModule(RandomSamplingLightningModule):
                 metric_name = self.reward_metric_name or 'default'
                 if metric_name not in self.trainer.callback_metrics:
                     raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
-                                   f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                                   f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
+                                   f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
+                                   'and remember to set on_step=True.')
                 metric = self.trainer.callback_metrics[metric_name]
                 reward: float = metric.item()

@@ -183,6 +186,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
             arc_opt.step()
             arc_opt.zero_grad()

+        return step_output
+
     def resample(self):
         """Resample the architecture with ENAS controller."""
         sample = self.controller.resample()
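The expanded `KeyError` message above also documents the fix on the user side: report the reward metric under the configured `reward_metric_name` with `on_step=True`, so ENAS finds it in `trainer.callback_metrics` when it runs the validation step. A hypothetical user module (the names and accuracy computation are illustrative):

import pytorch_lightning as pl

class MyEvaluationModule(pl.LightningModule):  # hypothetical user module
    # Assumes a forward() that returns class logits.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self(x).argmax(dim=1) == y).float().mean()
        # Key matches e.g. strategy.ENAS(reward_metric_name='val_acc');
        # on_step=True exposes the per-step value in trainer.callback_metrics.
        self.log('val_acc', acc, on_step=True)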
nni/retiarii/oneshot/pytorch/strategy.py

@@ -16,7 +16,6 @@ import warnings
 from typing import Any, Type

 import torch.nn as nn
-from torch.utils.data import DataLoader

 from nni.retiarii.graph import Model
 from nni.retiarii.strategy.base import BaseStrategy

@@ -25,7 +24,6 @@ from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
 from .base_lightning import BaseOneShotLightningModule
 from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
 from .sampling import EnasLightningModule, RandomSamplingLightningModule
-from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader


 class OneShotStrategy(BaseStrategy):

@@ -37,15 +35,18 @@ class OneShotStrategy(BaseStrategy):
         self.model: BaseOneShotLightningModule | None = None

-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
-        -> DataLoader | tuple[DataLoader, DataLoader]:
+    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
         """
-        One-shot strategy typically requires a customized dataloader.
-
-        If only train dataloader is produced, return one dataloader.
-        Otherwise, return train dataloader and valid loader as a tuple.
+        One-shot strategy typically requires fusing train and validation dataloader in an ad-hoc way.
+
+        As one-shot strategy doesn't try to open the blackbox of a batch,
+        theoretically, these dataloader can be
+        `any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
+
+        Returns
+        -------
+        A tuple of preprocessed train dataloaders and validation dataloaders.
         """
-        raise NotImplementedError()
+        return train_dataloaders, val_dataloaders

     def run(self, base_model: Model, applied_mutators):
         # one-shot strategy doesn't use ``applied_mutators``

@@ -64,18 +65,15 @@ class OneShotStrategy(BaseStrategy):
             raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
         evaluator_module: LightningModule = base_model.evaluator.module
+        evaluator_module.running_mode = 'oneshot'
         evaluator_module.set_model(py_model)
         self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
         evaluator: Lightning = base_model.evaluator
-        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
-            raise TypeError('Train or val dataloader is not set.')
-        dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
-        if isinstance(dataloader, tuple):
-            dataloader, val_loader = dataloader
-            evaluator.trainer.fit(self.model, dataloader, val_loader)
-        else:
-            evaluator.trainer.fit(self.model, dataloader)
+        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
+            raise TypeError('Training and validation dataloader are both required to set in evaluator for one-shot strategy.')
+        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
+        evaluator.trainer.fit(self.model, train_loader, val_loader)

     def export_top_models(self, top_k: int = 1) -> list[Any]:
         if self.model is None:

@@ -91,8 +89,12 @@ class DARTS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(DartsLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # By returning a dict, we make a CombinedLoader (in Lightning)
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class Proxyless(OneShotStrategy):

@@ -101,8 +103,11 @@ class Proxyless(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(ProxylessLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class GumbelDARTS(OneShotStrategy):

@@ -111,8 +116,11 @@ class GumbelDARTS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(GumbelDartsLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class ENAS(OneShotStrategy):

@@ -121,8 +129,13 @@ class ENAS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(EnasLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # Import locally to avoid import error on legacy PL version
+        from .dataloader import ConcatLoader
+        return ConcatLoader({
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }), None


 class RandomOneShot(OneShotStrategy):

@@ -130,6 +143,3 @@ class RandomOneShot(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(RandomSamplingLightningModule, **kwargs)
-
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return train_dataloader, val_dataloaders
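With `_get_dataloader` replaced by `preprocess_dataloader`, custom strategies no longer merge loaders by hand: the base implementation now passes both loaders through, and returning a dict (as DARTS does above) lets Lightning assemble a `CombinedLoader`. A minimal sketch of a hypothetical subclass following that contract:

from __future__ import annotations

from typing import Any

from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy


class MyOneShotStrategy(OneShotStrategy):  # hypothetical subclass, not part of the diff
    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
        # Same pattern as DARTS above: Lightning turns the dict into a CombinedLoader.
        # The second element is what trainer.fit() receives as the val dataloader.
        return {'train': train_dataloaders, 'val': val_dataloaders}, None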
nni/retiarii/oneshot/pytorch/utils.py

@@ -132,6 +132,7 @@ class AverageMeter:
 def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
         modules = []
+
     def apply(m):
         for name, child in m.named_children():
             if isinstance(child, type_name):
nni/retiarii/strategy/tpe_strategy.py

@@ -5,8 +5,6 @@ import logging
 import time
 from typing import Optional

-from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
-
 from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
 from .base import BaseStrategy

@@ -15,6 +13,9 @@ _logger = logging.getLogger(__name__)
 class TPESampler(Sampler):
     def __init__(self, optimize_mode='minimize'):
+        # Move import here to eliminate some warning messages about dill.
+        from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
+
         self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
         self.cur_sample: Optional[dict] = None
         self.index: Optional[int] = None
test/ut/retiarii/test_oneshot.py

@@ -15,6 +15,9 @@ from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
 from nni.retiarii.strategy import BaseStrategy

+pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+
+
 class DepthwiseSeparableConv(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()

@@ -171,7 +174,7 @@ class CustomOpValueChoiceNet(nn.Module):
         return F.log_softmax(x, dim=1)


-def _mnist_net(type_):
+def _mnist_net(type_, evaluator_kwargs):
     if type_ == 'simple':
         base_model = SimpleNet(False)
     elif type_ == 'simple_value_choice':

@@ -187,17 +190,18 @@ def _mnist_net(type_, evaluator_kwargs):
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
+    # Multi-GPU combined dataloader will break this subset sampler. Expected though.
     train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
     train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
     valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
     valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
     valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)

-    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
+    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)

     return base_model, evaluator


-def _multihead_attention_net():
+def _multihead_attention_net(evaluator_kwargs):
     base_model = MultiHeadAttentionNet(1)

     class AttentionRandDataset(Dataset):

@@ -222,19 +226,29 @@ def _multihead_attention_net(evaluator_kwargs):
     train_loader = DataLoader(train_set, batch_size=32)
     val_loader = DataLoader(val_set, batch_size=32)

-    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
+    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, **evaluator_kwargs)

     return base_model, evaluator


-def _test_strategy(strategy_, support_value_choice=True):
+def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
+    evaluator_kwargs = {
+        'max_epochs': 1
+    }
+    if multi_gpu:
+        evaluator_kwargs.update(
+            strategy='ddp',
+            accelerator='gpu',
+            devices=torch.cuda.device_count()
+        )
+
     to_test = [
         # (model, evaluator), support_or_net
-        (_mnist_net('simple'), True),
-        (_mnist_net('simple_value_choice'), support_value_choice),
-        (_mnist_net('value_choice'), support_value_choice),
-        (_mnist_net('repeat'), False),  # no strategy supports repeat currently
-        (_mnist_net('custom_op'), False),  # this is definitely a NO
-        (_multihead_attention_net(), support_value_choice),
+        (_mnist_net('simple', evaluator_kwargs), True),
+        (_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('repeat', evaluator_kwargs), False),  # no strategy supports repeat currently
+        (_mnist_net('custom_op', evaluator_kwargs), False),  # this is definitely a NO
+        (_multihead_attention_net(evaluator_kwargs), support_value_choice),
     ]

     for (base_model, evaluator), support_or_not in to_test:

@@ -256,17 +270,19 @@ def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
         experiment.run(config)


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_darts():
     _test_strategy(strategy.DARTS())


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_darts_multi_gpu():
+    _test_strategy(strategy.DARTS(), multi_gpu=True)
+
+
 def test_proxyless():
     _test_strategy(strategy.Proxyless(), False)


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_enas():
     def strategy_fn(base_model, evaluator):
         if isinstance(base_model, MultiHeadAttentionNet):

@@ -276,12 +292,20 @@ def test_enas():
     _test_strategy(strategy_fn)


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_enas_multi_gpu():
+    def strategy_fn(base_model, evaluator):
+        if isinstance(base_model, MultiHeadAttentionNet):
+            return strategy.ENAS(reward_metric_name='val_mse')
+        return strategy.ENAS(reward_metric_name='val_acc')
+
+    _test_strategy(strategy_fn, multi_gpu=True)
+
+
 def test_random():
     _test_strategy(strategy.RandomOneShot())


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_gumbel_darts():
     _test_strategy(strategy.GumbelDARTS())
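Outside these tests, the same multi-GPU path is enabled the way `_test_strategy` does it: `Classification` and `Regression` forward unknown keyword arguments to `pytorch_lightning.Trainer` via `**trainer_kwargs`. A hedged sketch, where `train_loader` and `valid_loader` stand in for traced DataLoaders like the ones built above:

import torch
from nni.retiarii.evaluator.pytorch import Classification

evaluator = Classification(
    train_dataloaders=train_loader,   # placeholder: a traced DataLoader
    val_dataloaders=valid_loader,     # placeholder: a traced DataLoader
    max_epochs=1,
    # The remaining kwargs are forwarded to pytorch_lightning.Trainer.
    strategy='ddp',
    accelerator='gpu',
    devices=torch.cuda.device_count(),
)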
test/ut/retiarii/test_oneshot_utils.py (new file, mode 100644)

import math
from typing import Union

import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('train_loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def test_concat_loader():
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    loaders = {
        'a': DataLoader(range(10), batch_size=4),
        'b': DataLoader(range(20), batch_size=5),
    }
    dataloader = ConcatLoader(loaders)
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i < 3:
            assert len(data) <= 4
            assert label == 'a'
        else:
            assert len(data) <= 5
            assert label == 'b'


def test_concat_loader_nested():
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    loaders = {
        'a': [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)],
        'b': DataLoader(range(20), batch_size=5),
    }
    dataloader = ConcatLoader(loaders)
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i < 3:
            assert isinstance(data, list) and len(data) == 2
            assert label == 'a'
        else:
            assert label == 'b'


@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]):
    """Inspired by tests/trainer/test_supporters.py in lightning."""
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
    dim = 3
    n1 = 8
    n2 = 6
    n3 = 9
    dataloader = ConcatLoader({
        'a': {
            'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
            'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
        },
        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
    }, mode=mode)
    expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
    print(len(dataloader))
    assert len(dataloader) == expected_length_before_ddp

    model = BoringModel()
    trainer = Trainer(
        strategy='ddp',
        accelerator='auto',
        devices=num_devices,
        replace_sampler_ddp=replace_sampler_ddp,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    expected_length_after_ddp = (
        math.ceil(n3 / trainer.num_devices) + \
        math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
        if replace_sampler_ddp
        else expected_length_before_ddp
    )
    print('Num devices =', trainer.num_devices)
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp