Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
802650ff
Unverified
Commit
802650ff
authored
Aug 12, 2022
by
Yuge Zhang
Committed by
GitHub
Aug 12, 2022
Browse files
Miscellaneous fixes of NAS (v2.9) (#5051)
parent
cd98c48f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
41 additions
and
22 deletions
+41
-22
examples/nas/multi-trial/nasbench101/network.py
examples/nas/multi-trial/nasbench101/network.py
+1
-1
examples/nas/multi-trial/nasbench201/network.py
examples/nas/multi-trial/nasbench201/network.py
+1
-1
nni/nas/benchmarks/utils.py
nni/nas/benchmarks/utils.py
+6
-1
nni/nas/evaluator/pytorch/cgo/evaluator.py
nni/nas/evaluator/pytorch/cgo/evaluator.py
+2
-2
nni/nas/evaluator/pytorch/lightning.py
nni/nas/evaluator/pytorch/lightning.py
+24
-13
nni/nas/execution/common/graph.py
nni/nas/execution/common/graph.py
+1
-1
nni/nas/nn/pytorch/layers.py
nni/nas/nn/pytorch/layers.py
+3
-0
test/algo/nas/test_cgo_engine.py
test/algo/nas/test_cgo_engine.py
+3
-3
No files found.
examples/nas/multi-trial/nasbench101/network.py
View file @
802650ff
...
@@ -114,7 +114,7 @@ class NasBench101TrainingModule(pl.LightningModule):
...
@@ -114,7 +114,7 @@ class NasBench101TrainingModule(pl.LightningModule):
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
return
{
return
{
'optimizer'
:
optimizer
,
'optimizer'
:
optimizer
,
'scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
'
lr_
scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
}
}
def
on_validation_epoch_end
(
self
):
def
on_validation_epoch_end
(
self
):
...
...
examples/nas/multi-trial/nasbench201/network.py
View file @
802650ff
...
@@ -103,7 +103,7 @@ class NasBench201TrainingModule(pl.LightningModule):
...
@@ -103,7 +103,7 @@ class NasBench201TrainingModule(pl.LightningModule):
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
return
{
return
{
'optimizer'
:
optimizer
,
'optimizer'
:
optimizer
,
'scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
'
lr_
scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
}
}
def
on_validation_epoch_end
(
self
):
def
on_validation_epoch_end
(
self
):
...
...
nni/nas/benchmarks/utils.py
View file @
802650ff
...
@@ -31,7 +31,12 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase:
...
@@ -31,7 +31,12 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase:
return
_loaded_benchmarks
[
benchmark
]
return
_loaded_benchmarks
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
)
try
:
load_or_download_file
(
local_path
,
url
)
except
FileNotFoundError
:
raise
FileNotFoundError
(
f
'Please use `nni.nas.benchmarks.download_benchmark("
{
benchmark
}
")` to setup the benchmark first before using it.'
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
return
_loaded_benchmarks
[
benchmark
]
return
_loaded_benchmarks
[
benchmark
]
...
...
nni/nas/evaluator/pytorch/cgo/evaluator.py
View file @
802650ff
...
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
...
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
import
nni
import
nni
from
..lightning
import
LightningModule
,
_
AccuracyWithLogits
,
Lightning
from
..lightning
import
LightningModule
,
AccuracyWithLogits
,
Lightning
from
.trainer
import
Trainer
from
.trainer
import
Trainer
__all__
=
[
__all__
=
[
...
@@ -148,7 +148,7 @@ class _ClassificationModule(_MultiModelSupervisedLearningModule):
...
@@ -148,7 +148,7 @@ class _ClassificationModule(_MultiModelSupervisedLearningModule):
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
super
().
__init__
(
criterion
,
{
'acc'
:
_
AccuracyWithLogits
},
super
().
__init__
(
criterion
,
{
'acc'
:
AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
...
...
nni/nas/evaluator/pytorch/lightning.py
View file @
802650ff
...
@@ -27,7 +27,7 @@ from nni.typehint import Literal
...
@@ -27,7 +27,7 @@ from nni.typehint import Literal
__all__
=
[
__all__
=
[
'LightningModule'
,
'Trainer'
,
'DataLoader'
,
'Lightning'
,
'Classification'
,
'Regression'
,
'LightningModule'
,
'Trainer'
,
'DataLoader'
,
'Lightning'
,
'Classification'
,
'Regression'
,
'
_AccuracyWithLogits'
,
'_
SupervisedLearningModule'
,
'
_
ClassificationModule'
,
'
_
RegressionModule'
,
'SupervisedLearningModule'
,
'ClassificationModule'
,
'RegressionModule'
,
'AccuracyWithLogits'
,
# FIXME: hack to make it importable for tests
# FIXME: hack to make it importable for tests
]
]
...
@@ -102,12 +102,15 @@ class Lightning(Evaluator):
...
@@ -102,12 +102,15 @@ class Lightning(Evaluator):
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
fit_kwargs
Keyword arguments passed to ``trainer.fit()``.
"""
"""
def
__init__
(
self
,
lightning_module
:
LightningModule
,
trainer
:
Trainer
,
def
__init__
(
self
,
lightning_module
:
LightningModule
,
trainer
:
Trainer
,
train_dataloaders
:
Optional
[
Any
]
=
None
,
train_dataloaders
:
Optional
[
Any
]
=
None
,
val_dataloaders
:
Optional
[
Any
]
=
None
,
val_dataloaders
:
Optional
[
Any
]
=
None
,
train_dataloader
:
Optional
[
Any
]
=
None
):
train_dataloader
:
Optional
[
Any
]
=
None
,
fit_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
assert
isinstance
(
lightning_module
,
LightningModule
),
f
'Lightning module must be an instance of
{
__name__
}
.LightningModule.'
assert
isinstance
(
lightning_module
,
LightningModule
),
f
'Lightning module must be an instance of
{
__name__
}
.LightningModule.'
if
train_dataloader
is
not
None
:
if
train_dataloader
is
not
None
:
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
...
@@ -117,7 +120,7 @@ class Lightning(Evaluator):
...
@@ -117,7 +120,7 @@ class Lightning(Evaluator):
else
:
else
:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert
(
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
))
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
assert
(
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
))
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
f
'Trainer must be imported from
{
__name__
}
or nni.
nas
.evaluator.pytorch.cgo.trainer'
f
'Trainer must be imported from
{
__name__
}
or nni.
retiarii
.evaluator.pytorch.cgo.trainer'
if
not
_check_dataloader
(
train_dataloaders
):
if
not
_check_dataloader
(
train_dataloaders
):
warnings
.
warn
(
f
'Please try to wrap PyTorch DataLoader with nni.trace or '
warnings
.
warn
(
f
'Please try to wrap PyTorch DataLoader with nni.trace or '
f
'import DataLoader from
{
__name__
}
:
{
train_dataloaders
}
'
,
f
'import DataLoader from
{
__name__
}
:
{
train_dataloaders
}
'
,
...
@@ -130,6 +133,7 @@ class Lightning(Evaluator):
...
@@ -130,6 +133,7 @@ class Lightning(Evaluator):
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
train_dataloaders
=
train_dataloaders
self
.
train_dataloaders
=
train_dataloaders
self
.
val_dataloaders
=
val_dataloaders
self
.
val_dataloaders
=
val_dataloaders
self
.
fit_kwargs
=
fit_kwargs
or
{}
@
staticmethod
@
staticmethod
def
_load
(
ir
):
def
_load
(
ir
):
...
@@ -178,7 +182,7 @@ class Lightning(Evaluator):
...
@@ -178,7 +182,7 @@ class Lightning(Evaluator):
The model to fit.
The model to fit.
"""
"""
self
.
module
.
set_model
(
model
)
self
.
module
.
set_model
(
model
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
def
_check_dataloader
(
dataloader
):
def
_check_dataloader
(
dataloader
):
...
@@ -194,7 +198,7 @@ def _check_dataloader(dataloader):
...
@@ -194,7 +198,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
### The following are some commonly used Lightning modules ###
class
_
SupervisedLearningModule
(
LightningModule
):
class
SupervisedLearningModule
(
LightningModule
):
trainer
:
pl
.
Trainer
trainer
:
pl
.
Trainer
...
@@ -273,19 +277,19 @@ class _SupervisedLearningModule(LightningModule):
...
@@ -273,19 +277,19 @@ class _SupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
class
_
AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
class
AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn_functional
.
softmax
(
pred
,
dim
=-
1
),
target
)
return
super
().
update
(
nn_functional
.
softmax
(
pred
,
dim
=-
1
),
target
)
@
nni
.
trace
@
nni
.
trace
class
_
ClassificationModule
(
_
SupervisedLearningModule
):
class
ClassificationModule
(
SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
CrossEntropyLoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
export_onnx
:
bool
=
True
):
export_onnx
:
bool
=
True
):
super
().
__init__
(
criterion
,
{
'acc'
:
_
AccuracyWithLogits
},
super
().
__init__
(
criterion
,
{
'acc'
:
AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
export_onnx
=
export_onnx
)
...
@@ -341,14 +345,14 @@ class Classification(Lightning):
...
@@ -341,14 +345,14 @@ class Classification(Lightning):
if
train_dataloader
is
not
None
:
if
train_dataloader
is
not
None
:
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
train_dataloaders
=
train_dataloader
train_dataloaders
=
train_dataloader
module
=
_
ClassificationModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
module
=
ClassificationModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
@
nni
.
trace
@
nni
.
trace
class
_
RegressionModule
(
_
SupervisedLearningModule
):
class
RegressionModule
(
SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
MSELoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
...
@@ -406,7 +410,14 @@ class Regression(Lightning):
...
@@ -406,7 +410,14 @@ class Regression(Lightning):
if
train_dataloader
is
not
None
:
if
train_dataloader
is
not
None
:
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
train_dataloaders
=
train_dataloader
train_dataloaders
=
train_dataloader
module
=
_
RegressionModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
module
=
RegressionModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
# Alias for backwards compatibility
_SupervisedLearningModule
=
SupervisedLearningModule
_AccuracyWithLogits
=
AccuracyWithLogits
_ClassificationModule
=
ClassificationModule
_RegressionModule
=
RegressionModule
nni/nas/execution/common/graph.py
View file @
802650ff
...
@@ -13,7 +13,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
...
@@ -13,7 +13,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
Optional
,
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
.muta
tor
import
Mutator
from
nni.nas
.muta
ble
import
Mutator
from
nni.nas.evaluator
import
Evaluator
from
nni.nas.evaluator
import
Evaluator
from
nni.nas.utils
import
uid
from
nni.nas.utils
import
uid
...
...
nni/nas/nn/pytorch/layers.py
View file @
802650ff
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
# If you've seen lint errors like `"Sequential" is not a known member of module`,
# please run `python test/vso_tools/trigger_import.py` to generate `_layers.py`.
from
pathlib
import
Path
from
pathlib
import
Path
# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
...
...
test/algo/nas/test_cgo_engine.py
View file @
802650ff
...
@@ -152,7 +152,7 @@ def _new_trainer():
...
@@ -152,7 +152,7 @@ def _new_trainer():
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
_
AccuracyWithLogits
})
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
AccuracyWithLogits
})
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
max_epochs
=
1
,
max_epochs
=
1
,
...
@@ -201,7 +201,7 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -201,7 +201,7 @@ class CGOEngineTest(unittest.TestCase):
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
_
AccuracyWithLogits
},
n_models
=
2
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
AccuracyWithLogits
},
n_models
=
2
)
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
max_epochs
=
1
,
max_epochs
=
1
,
...
@@ -225,7 +225,7 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -225,7 +225,7 @@ class CGOEngineTest(unittest.TestCase):
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
_
AccuracyWithLogits
},
n_models
=
2
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
AccuracyWithLogits
},
n_models
=
2
)
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
max_epochs
=
1
,
max_epochs
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment