OpenDAS / nni · Commit 39ec21ca (unverified)

Multi-GPU support of one-shot NAS (#4603)

Authored by Frandium, committed by GitHub on May 20, 2022. Parent commit: b4559f60.

Showing 10 changed files with 382 additions and 94 deletions (+382 -94).
Changed files:

nni/retiarii/evaluator/pytorch/lightning.py (+60 -24)
nni/retiarii/oneshot/pytorch/base_lightning.py (+10 -8)
nni/retiarii/oneshot/pytorch/dataloader.py (+77 -0)
nni/retiarii/oneshot/pytorch/differentiable.py (+3 -2)
nni/retiarii/oneshot/pytorch/sampling.py (+19 -14)
nni/retiarii/oneshot/pytorch/strategy.py (+38 -28)
nni/retiarii/oneshot/pytorch/utils.py (+1 -0)
nni/retiarii/strategy/tpe_strategy.py (+3 -2)
test/ut/retiarii/test_oneshot.py (+40 -16)
test/ut/retiarii/test_oneshot_utils.py (+131 -0)
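The user-facing effect of this commit is that a one-shot strategy can now train with DDP by passing multi-GPU trainer arguments through the evaluator, the same flags the new tests below use. A minimal sketch (the toy dataset and loaders are illustrative, not part of the commit):

import torch
from torch.utils.data import TensorDataset
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader

# Toy data; any nni-traced DataLoader works here.
dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=8)
val_loader = DataLoader(dataset, batch_size=8)

# Extra keyword arguments are forwarded to pytorch_lightning.Trainer; these are
# the flags the new multi-GPU tests pass when more than one GPU is available.
evaluator = Classification(
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    max_epochs=1,
    strategy='ddp',
    accelerator='gpu',
    devices=torch.cuda.device_count(),
)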
nni/retiarii/evaluator/pytorch/lightning.py

@@ -4,7 +4,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable, Type
+from typing import Any, Dict, Union, Optional, List, Callable, Type

 import pytorch_lightning as pl
 import torch.nn as nn
@@ -22,6 +22,7 @@ except ImportError:
     cgo_import_failed = True

 from nni.retiarii.graph import Evaluator
+from nni.typehint import Literal

 __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
     See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
     """

+    running_mode: Literal['multi', 'oneshot'] = 'multi'
+    """An indicator of whether current module is running in a multi-trial experiment or an one-shot.
+    This flag should be automatically set by experiments when they start to run.
+    """
+
     def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
         """Set the inner model (architecture) to train / evaluate.
@@ -59,6 +65,7 @@ DataLoader.__doc__ = """
 Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
 """

 @nni.trace
 class Lightning(Evaluator):
     """
@@ -74,51 +81,67 @@ class Lightning(Evaluator):
     Parameters
     ----------
-    lightning_module : LightningModule
+    lightning_module
         Lightning module that defines the training logic.
-    trainer : Trainer
+    trainer
         Lightning trainer that handles the training.
-    train_dataloders : DataLoader
+    train_dataloders
         Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
         If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
+        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
-    val_dataloaders : DataLoader or List of DataLoader
+    val_dataloaders
         Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
         If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
     """

     def __init__(self, lightning_module: LightningModule, trainer: Trainer,
-                 train_dataloader: Optional[DataLoader] = None,
-                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
+                 train_dataloaders: Optional[Any] = None,
+                 val_dataloaders: Optional[Any] = None,
+                 train_dataloader: Optional[Any] = None):
         assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         if cgo_import_failed:
             assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
         else:
             # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
             assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
                 f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
-        assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
-        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
+        if not _check_dataloader(train_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {train_dataloaders}',
+                          RuntimeWarning)
+        if not _check_dataloader(val_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {val_dataloaders}',
+                          RuntimeWarning)
         self.module = lightning_module
         self.trainer = trainer
-        self.train_dataloader = train_dataloader
+        self.train_dataloaders = train_dataloaders
         self.val_dataloaders = val_dataloaders

     @staticmethod
     def _load(ir):
-        return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
+        return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])

     def _dump(self):
         return {
             'type': self.__class__,
             'module': self.module,
             'trainer': self.trainer,
-            'train_dataloader': self.train_dataloader,
+            'train_dataloaders': self.train_dataloaders,
             'val_dataloaders': self.val_dataloaders
         }

     def _execute(self, model_cls):
         return self.fit(model_cls)

+    @property
+    def train_dataloader(self):
+        warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
+
     def __eq__(self, other):
         eq_func = False
         eq_args = False
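For downstream code, the rename is backwards compatible: the old keyword is accepted through the new compatibility parameter and emits a DeprecationWarning, and the `train_dataloader` property warns on access. A short migration sketch (toy loaders, illustrative only):

import torch
from torch.utils.data import TensorDataset
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader

dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
loader = DataLoader(dataset, batch_size=4)

# Deprecated spelling: still accepted, but warns with DeprecationWarning.
#   Classification(train_dataloader=loader, val_dataloaders=loader, max_epochs=1)
# Preferred spelling after this commit (plural, matching pytorch-lightning):
evaluator = Classification(train_dataloaders=loader, val_dataloaders=loader, max_epochs=1)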
@@ -146,15 +169,18 @@ class Lightning(Evaluator):
             The model to fit.
         """
         self.module.set_model(model)
-        return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
+        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)


 def _check_dataloader(dataloader):
-    if dataloader is None:
-        return True
+    # Check the type of dataloader recursively.
     if isinstance(dataloader, list):
         return all([_check_dataloader(d) for d in dataloader])
-    return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
+    if isinstance(dataloader, dict):
+        return all([_check_dataloader(v) for v in dataloader.values()])
+    if isinstance(dataloader, torch_data.DataLoader):
+        return is_traceable(dataloader)
+    return True


 ### The following are some commonly used Lightning modules ###
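The rewritten `_check_dataloader` walks lists and dicts recursively and only demands traceability of actual `DataLoader` instances; any other type (for example a Lightning combined loader, or `None`) now passes, turning a hard assertion into a warning-based contract. A rough illustration of what validates (hypothetical values):

from nni.retiarii.evaluator.pytorch.lightning import DataLoader

single = DataLoader(range(10), batch_size=2)          # traced DataLoader: accepted
nested = {'train': single, 'val': [single, single]}   # dicts and lists are checked recursively
# Types the checker does not recognize simply return True instead of failing an assert.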
@@ -176,7 +202,6 @@ class _SupervisedLearningModule(LightningModule):
         if export_onnx is None or export_onnx is True:
             self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
-            self.export_onnx.parent.mkdir(exist_ok=True)
         elif export_onnx:
             self.export_onnx = Path(export_onnx)
         else:
@@ -199,7 +224,8 @@ class _SupervisedLearningModule(LightningModule):
         x, y = batch
         y_hat = self(x)
-        if self.export_onnx is not None:
+        if self.running_mode == 'multi' and self.export_onnx is not None:
+            self.export_onnx.parent.mkdir(exist_ok=True)
             try:
                 self.to_onnx(self.export_onnx, x, export_params=True)
             except RuntimeError as e:
@@ -221,10 +247,12 @@ class _SupervisedLearningModule(LightningModule):
         return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
-        nni.report_intermediate_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_intermediate_result(self._get_validation_metrics())

     def on_fit_end(self):
-        nni.report_final_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_final_result(self._get_validation_metrics())

     def _get_validation_metrics(self):
         if len(self.metrics) == 1:
@@ -283,14 +311,18 @@ class Classification(Lightning):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                        weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)


 @nni.trace
@@ -336,11 +368,15 @@ class Regression(Lightning):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                    weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
nni/retiarii/oneshot/pytorch/base_lightning.py

@@ -18,6 +18,7 @@ import nni.retiarii.nn.pytorch as nas_nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
 from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.typehint import Literal
 from .supermodule.base import BaseSuperNetModule

 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
@@ -334,21 +335,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return arc_optimizers + w_optimizers, lr_schedulers

     def on_train_start(self):
-        # redirect the access to trainer/log to this module
-        # but note that we might be missing other attributes,
-        # which could potentially be a problem
-        self.model.trainer = self.trainer  # type: ignore
-        self.model.log = self.log
         return self.model.on_train_start()

     def on_train_end(self):
         return self.model.on_train_end()

     def on_fit_start(self):
-        return self.model.on_train_start()
+        # redirect the access to trainer/log to this module
+        # but note that we might be missing other attributes,
+        # which could potentially be a problem
+        self.model.trainer = self.trainer  # type: ignore
+        self.model.log = self.log
+        return self.model.on_fit_start()

     def on_fit_end(self):
-        return self.model.on_train_end()
+        return self.model.on_fit_end()

     def on_train_batch_start(self, batch, batch_idx, unused=0):
         return self.model.on_train_batch_start(batch, batch_idx, unused)
@@ -356,6 +357,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
     def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
         return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

+    # Deprecated hooks in pytorch-lightning
     def on_epoch_start(self):
         return self.model.on_epoch_start()
@@ -427,7 +429,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         else:
             apply(lr_schedulers)

-    def call_weight_optimizers(self, method):
+    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
         """
         Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
         class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
nni/retiarii/oneshot/pytorch/dataloader.py (new file)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Any

from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator


class ConcatLoader(CombinedLoader):
    """This loader is the same as CombinedLoader in PyTorch-Lightning, but concatenates sub-loaders
    instead of loading them in parallel.

    Parameters
    ----------
    loaders
        For example, ::

            {
                "train": DataLoader(train_dataset),
                "val": DataLoader(val_dataset)
            }

        In this example, the loader will first produce the batches from "train", then "val".
    mode
        Only support "min_size" for now.
    """

    def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
        # FIXME: max_size_cycle will make dataloaders cycle iterators,
        # causing extra problems.
        if mode != 'min_size':
            raise ValueError('Only min_size mode is supported now.')
        super().__init__(loaders, mode)

    def __iter__(self) -> Any:
        """Replace the super-class iterator with ours."""
        self._try_to_patch_pytorch_dataloader()
        iterator = ConcatLoaderIterator(self.loaders)
        # handle fault tolerant restart.
        self.on_restart(iterator)
        self._iterator = iterator
        return iterator

    @staticmethod
    def _try_to_patch_pytorch_dataloader():
        """Copied from CombinedLoader."""
        from torch.utils.data.dataloader import _BaseDataLoaderIter

        # prevent `NotImplementedError` from PyTorch:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
        def __getstate__patch__(*_):
            return {}

        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore

    def __len__(self) -> int:
        return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))


class ConcatLoaderIterator(CombinedLoaderIterator):
    """Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

    def __next__(self) -> Any:
        """Fetches the next batch from multiple data loaders,
        by looking for the first iterator that isn't exhausted yet.
        """
        if not len(self.loader_iters) == len(self.loaders):
            raise RuntimeError('loader_iters must have the same length as loaders.')
        for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
            try:
                return (self.request_next_batch(iterator), loader_name)
            except StopIteration:
                if i + 1 == len(self.loader_iters):
                    raise
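The new loader exhausts its sub-loaders one after another and tags every batch with the key of the loader it came from, which is how `EnasLightningModule.training_step` later distinguishes train batches from val batches. A minimal usage sketch, mirroring the unit tests added in this commit:

from torch.utils.data import DataLoader
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

loader = ConcatLoader({
    'train': DataLoader(range(10), batch_size=4),  # 3 batches
    'val': DataLoader(range(20), batch_size=5),    # 4 batches
})
for batch, source in loader:   # source is 'train' for the first 3 batches, then 'val'
    print(source, len(batch))
print(len(loader))             # 7, i.e. 3 + 4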
nni/retiarii/oneshot/pytorch/differentiable.py

@@ -75,8 +75,9 @@ class DartsLightningModule(BaseOneShotLightningModule):
         if not isinstance(arc_optim, optim.Optimizer):
             raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

-        # The InterleavedTrainValDataLoader yields both train and val data in a batch
-        trn_batch, val_batch = batch
+        # DARTS strategy makes sure that ``train`` and ``val`` must be in the batch
+        trn_batch = batch['train']
+        val_batch = batch['val']

         # phase 1: architecture step
         # The _resample hook is kept for some darts-based NAS methods like proxyless.
nni/retiarii/oneshot/pytorch/sampling.py

@@ -133,29 +133,30 @@ class EnasLightningModule(RandomSamplingLightningModule):
     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)

-    def training_step(self, batch, batch_idx):
-        # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
-        batch, source = batch
-        if source == 'train':
-            # step 1: train model params
-            self.resample()
+    def training_step(self, batch_packed, batch_idx):
+        batch, mode = batch_packed
+        if mode == 'train':
+            # train model params
+            with torch.no_grad():
+                self.resample()
             self.call_weight_optimizers('zero_grad')
-            loss_and_metrics = self.model.training_step(batch, batch_idx)
-            w_step_loss = loss_and_metrics['loss'] \
-                if isinstance(loss_and_metrics, dict) else loss_and_metrics
+            step_output = self.model.training_step(batch, batch_idx)
+            w_step_loss = step_output['loss'] \
+                if isinstance(step_output, dict) else step_output
             self.manual_backward(w_step_loss)
             self.call_weight_optimizers('step')
-            return loss_and_metrics
-        if source == 'val':
-            # step 2: train ENAS agent
+        else:
+            # train ENAS agent
             arc_opt = self.architecture_optimizers()
             if not isinstance(arc_opt, optim.Optimizer):
                 raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
             arc_opt.zero_grad()
             self.resample()
-            self.model.validation_step(batch, batch_idx)
+            step_output = self.model.validation_step(batch, batch_idx)

             # use the default metric of self.model as reward function
             if len(self.trainer.callback_metrics) == 1:
                 _, metric = next(iter(self.trainer.callback_metrics.items()))
@@ -163,7 +164,9 @@ class EnasLightningModule(RandomSamplingLightningModule):
         metric_name = self.reward_metric_name or 'default'
         if metric_name not in self.trainer.callback_metrics:
             raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
-                           f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                           f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
+                           f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
+                           'and remember to set on_step=True.')
         metric = self.trainer.callback_metrics[metric_name]
         reward: float = metric.item()
@@ -183,6 +186,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
         arc_opt.step()
         arc_opt.zero_grad()

+        return step_output
+
     def resample(self):
         """Resample the architecture with ENAS controller."""
         sample = self.controller.resample()
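The expanded KeyError message also tells users how to satisfy `reward_metric_name`: log the metric themselves. A hedged sketch of what that looks like in a user's LightningModule (the module, metric name, and computation are illustrative):

import torch
from pytorch_lightning import LightningModule

class MyModule(LightningModule):  # hypothetical user module
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.net(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self(x).argmax(dim=-1) == y).float().mean()
        # strategy.ENAS(reward_metric_name='val_acc') looks this key up in
        # trainer.callback_metrics; on_step=True keeps a per-step value
        # available, as the new error message advises.
        self.log('val_acc', acc, on_step=True)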
nni/retiarii/oneshot/pytorch/strategy.py

@@ -16,7 +16,6 @@ import warnings
 from typing import Any, Type

 import torch.nn as nn
-from torch.utils.data import DataLoader

 from nni.retiarii.graph import Model
 from nni.retiarii.strategy.base import BaseStrategy
@@ -25,7 +24,6 @@ from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
 from .base_lightning import BaseOneShotLightningModule
 from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
 from .sampling import EnasLightningModule, RandomSamplingLightningModule
-from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader


 class OneShotStrategy(BaseStrategy):
@@ -37,15 +35,18 @@ class OneShotStrategy(BaseStrategy):
         self.model: BaseOneShotLightningModule | None = None

-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
-        -> DataLoader | tuple[DataLoader, DataLoader]:
+    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
         """
-        One-shot strategy typically requires a customized dataloader.
-        As one-shot strategy doesn't try to open the blackbox of a batch,
-        theoretically, these dataloader can be
-        `any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
+        One-shot strategy typically requires fusing train and validation dataloader in an ad-hoc way.
+        If only train dataloader is produced, return one dataloader.
+        Otherwise, return train dataloader and valid loader as a tuple.
+
+        Returns
+        -------
+        A tuple of preprocessed train dataloaders and validation dataloaders.
         """
-        raise NotImplementedError()
+        return train_dataloaders, val_dataloaders

     def run(self, base_model: Model, applied_mutators):
         # one-shot strategy doesn't use ``applied_mutators``
@@ -64,18 +65,15 @@ class OneShotStrategy(BaseStrategy):
             raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
         evaluator_module: LightningModule = base_model.evaluator.module
+        evaluator_module.running_mode = 'oneshot'
         evaluator_module.set_model(py_model)
         self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
         evaluator: Lightning = base_model.evaluator
-        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
-            raise TypeError('Train or val dataloader is not set.')
-        dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
-        if isinstance(dataloader, tuple):
-            dataloader, val_loader = dataloader
-            evaluator.trainer.fit(self.model, dataloader, val_loader)
-        else:
-            evaluator.trainer.fit(self.model, dataloader)
+        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
+            raise TypeError('Training and validation dataloader are both required to set in evaluator for one-shot strategy.')
+        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
+        evaluator.trainer.fit(self.model, train_loader, val_loader)

     def export_top_models(self, top_k: int = 1) -> list[Any]:
         if self.model is None:
@@ -91,8 +89,12 @@ class DARTS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(DartsLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # By returning a dict, we make a CombinedLoader (in Lightning)
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class Proxyless(OneShotStrategy):
@@ -101,8 +103,11 @@ class Proxyless(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(ProxylessLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class GumbelDARTS(OneShotStrategy):
@@ -111,8 +116,11 @@ class GumbelDARTS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(GumbelDartsLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class ENAS(OneShotStrategy):
@@ -121,8 +129,13 @@ class ENAS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(EnasLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # Import locally to avoid import error on legacy PL version
+        from .dataloader import ConcatLoader
+        return ConcatLoader({
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }), None


 class RandomOneShot(OneShotStrategy):
@@ -130,6 +143,3 @@ class RandomOneShot(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(RandomSamplingLightningModule, **kwargs)
-
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return train_dataloader, val_dataloaders
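The dict returned by the DARTS-family overrides leans on a pytorch-lightning convention: when `trainer.fit` receives a mapping of dataloaders, Lightning wraps it in a `CombinedLoader` and yields dict-shaped batches, which matches the `batch['train']` / `batch['val']` unpacking in `DartsLightningModule` above. A rough sketch of that batch shape (loader contents are illustrative):

from torch.utils.data import DataLoader
from pytorch_lightning.trainer.supporters import CombinedLoader

combined = CombinedLoader({
    'train': DataLoader(range(8), batch_size=2),
    'val': DataLoader(range(8), batch_size=2),
}, mode='min_size')

for batch in combined:
    # Each batch is a dict keyed like the input mapping.
    trn_batch, val_batch = batch['train'], batch['val']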
nni/retiarii/oneshot/pytorch/utils.py

@@ -132,6 +132,7 @@ class AverageMeter:

 def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
         modules = []
+
     def apply(m):
         for name, child in m.named_children():
             if isinstance(child, type_name):
nni/retiarii/strategy/tpe_strategy.py

@@ -5,8 +5,6 @@ import logging
 import time
 from typing import Optional

-from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
-
 from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
 from .base import BaseStrategy
@@ -15,6 +13,9 @@ _logger = logging.getLogger(__name__)
 class TPESampler(Sampler):
     def __init__(self, optimize_mode='minimize'):
+        # Move import here to eliminate some warning messages about dill.
+        from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
+
         self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
         self.cur_sample: Optional[dict] = None
         self.index: Optional[int] = None
test/ut/retiarii/test_oneshot.py

@@ -15,6 +15,9 @@ from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
 from nni.retiarii.strategy import BaseStrategy

+pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+

 class DepthwiseSeparableConv(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()
@@ -171,7 +174,7 @@ class CustomOpValueChoiceNet(nn.Module):
         return F.log_softmax(x, dim=1)

-def _mnist_net(type_):
+def _mnist_net(type_, evaluator_kwargs):
     if type_ == 'simple':
         base_model = SimpleNet(False)
     elif type_ == 'simple_value_choice':
@@ -187,17 +190,18 @@ def _mnist_net(type_):
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
+    # Multi-GPU combined dataloader will break this subset sampler. Expected though.
     train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
     train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
     valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
     valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
     valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
-    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
+    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
     return base_model, evaluator

-def _multihead_attention_net():
+def _multihead_attention_net(evaluator_kwargs):
     base_model = MultiHeadAttentionNet(1)

     class AttentionRandDataset(Dataset):
@@ -222,19 +226,29 @@ def _multihead_attention_net():
     train_loader = DataLoader(train_set, batch_size=32)
     val_loader = DataLoader(val_set, batch_size=32)

-    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
+    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, **evaluator_kwargs)
     return base_model, evaluator

-def _test_strategy(strategy_, support_value_choice=True):
+def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
+    evaluator_kwargs = {
+        'max_epochs': 1
+    }
+    if multi_gpu:
+        evaluator_kwargs.update(
+            strategy='ddp',
+            accelerator='gpu',
+            devices=torch.cuda.device_count()
+        )
+
     to_test = [
         # (model, evaluator), support_or_net
-        (_mnist_net('simple'), True),
-        (_mnist_net('simple_value_choice'), support_value_choice),
-        (_mnist_net('value_choice'), support_value_choice),
-        (_mnist_net('repeat'), False),  # no strategy supports repeat currently
-        (_mnist_net('custom_op'), False),  # this is definitely a NO
-        (_multihead_attention_net(), support_value_choice),
+        (_mnist_net('simple', evaluator_kwargs), True),
+        (_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('repeat', evaluator_kwargs), False),  # no strategy supports repeat currently
+        (_mnist_net('custom_op', evaluator_kwargs), False),  # this is definitely a NO
+        (_multihead_attention_net(evaluator_kwargs), support_value_choice),
     ]
     for (base_model, evaluator), support_or_not in to_test:
@@ -256,17 +270,19 @@ def _test_strategy(strategy_, support_value_choice=True):
         experiment.run(config)

-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_darts():
     _test_strategy(strategy.DARTS())

+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_darts_multi_gpu():
+    _test_strategy(strategy.DARTS(), multi_gpu=True)
+
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_proxyless():
     _test_strategy(strategy.Proxyless(), False)

-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_enas():
     def strategy_fn(base_model, evaluator):
         if isinstance(base_model, MultiHeadAttentionNet):
@@ -276,12 +292,20 @@ def test_enas():
     _test_strategy(strategy_fn)

+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_enas_multi_gpu():
+    def strategy_fn(base_model, evaluator):
+        if isinstance(base_model, MultiHeadAttentionNet):
+            return strategy.ENAS(reward_metric_name='val_mse')
+        return strategy.ENAS(reward_metric_name='val_acc')
+
+    _test_strategy(strategy_fn, multi_gpu=True)
+
-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_random():
     _test_strategy(strategy.RandomOneShot())

-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_gumbel_darts():
     _test_strategy(strategy.GumbelDARTS())
test/ut/retiarii/test_oneshot_utils.py (new file)

import math
from typing import Union

import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('train_loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def test_concat_loader():
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
    loaders = {
        'a': DataLoader(range(10), batch_size=4),
        'b': DataLoader(range(20), batch_size=5),
    }
    dataloader = ConcatLoader(loaders)
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i < 3:
            assert len(data) <= 4
            assert label == 'a'
        else:
            assert len(data) <= 5
            assert label == 'b'


def test_concat_loader_nested():
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
    loaders = {
        'a': [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)],
        'b': DataLoader(range(20), batch_size=5),
    }
    dataloader = ConcatLoader(loaders)
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i < 3:
            assert isinstance(data, list) and len(data) == 2
            assert label == 'a'
        else:
            assert label == 'b'


@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]):
    """Inspired by tests/trainer/test_supporters.py in lightning."""
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
    mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
    dim = 3
    n1 = 8
    n2 = 6
    n3 = 9
    dataloader = ConcatLoader({
        'a': {
            'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
            'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
        },
        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
    }, mode=mode)
    expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
    print(len(dataloader))
    assert len(dataloader) == expected_length_before_ddp
    model = BoringModel()
    trainer = Trainer(
        strategy='ddp',
        accelerator='auto',
        devices=num_devices,
        replace_sampler_ddp=replace_sampler_ddp,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    expected_length_after_ddp = (
        math.ceil(n3 / trainer.num_devices) +
        math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
        if replace_sampler_ddp else expected_length_before_ddp
    )
    print('Num devices =', trainer.num_devices)
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp
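As a sanity check on the `len(dataloader) == 7` assertions above: `ConcatLoader.__len__` sums the batch counts of its sub-loaders, so with batch sizes 4 and 5 over 10 and 20 samples, ceil(10/4) + ceil(20/5) = 3 + 4 = 7. The nested case works out the same way because the inherited `CombinedLoader` logic reduces the list under 'a' to min(ceil(10/4), ceil(20/6)) = 3 batches in 'min_size' mode.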